Commit 4447c247 (unverified), authored by davilu-nvidia, committed by GitHub
Browse files

fix: resolve sgl be E/P/D multimodal routing issues (#5500)


Signed-off-by: davilu <davilu@nvidia.com>
parent 417ce216
...@@ -487,23 +487,13 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con ...@@ -487,23 +487,13 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
await pd_worker_client.wait_for_instances() await pd_worker_client.wait_for_instances()
ready_event = asyncio.Event()
try: try:
await asyncio.gather( # Encode Worker is an internal component, should not register with Frontend
generate_endpoint.serve_endpoint( # Only needs to provide internal service endpoint for Processor to call
handler.generate, await generate_endpoint.serve_endpoint(
graceful_shutdown=True, handler.generate,
metrics_labels=[("model", server_args.served_model_name)], graceful_shutdown=True,
), metrics_labels=[("model", server_args.served_model_name)],
register_llm_with_readiness_gate(
None, # encode worker doesn't have engine
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
readiness_gate=ready_event,
),
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
...@@ -542,21 +532,32 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -542,21 +532,32 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event() ready_event = asyncio.Event()
try: try:
await asyncio.gather( if config.serving_mode == DisaggregationMode.DECODE:
generate_endpoint.serve_endpoint( # Decode Worker is an internal component, should not register with Frontend
# Only needs to provide internal service endpoint for Processor to call
await generate_endpoint.serve_endpoint(
handler.generate, handler.generate,
metrics_labels=[("model", server_args.served_model_name)], metrics_labels=[("model", server_args.served_model_name)],
graceful_shutdown=True, graceful_shutdown=True,
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
), )
register_llm_with_readiness_gate( else:
engine, # In aggregated mode, need to register with Frontend
generate_endpoint, await asyncio.gather(
server_args, generate_endpoint.serve_endpoint(
dynamo_args, handler.generate,
readiness_gate=ready_event, metrics_labels=[("model", server_args.served_model_name)],
), graceful_shutdown=True,
) health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
)
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
raise raise
...@@ -580,23 +581,15 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co ...@@ -580,23 +581,15 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
await handler.async_init() await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
ready_event = asyncio.Event()
try: try:
await asyncio.gather( # Prefill Worker is an internal component, should not register with Frontend
generate_endpoint.serve_endpoint( # Only needs to provide internal service endpoint for Decode Worker to call
handler.generate, await generate_endpoint.serve_endpoint(
graceful_shutdown=True, handler.generate,
metrics_labels=[("model", server_args.served_model_name)], graceful_shutdown=True,
health_check_payload=health_check_payload, metrics_labels=[("model", server_args.served_model_name)],
), health_check_payload=health_check_payload,
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
......
...@@ -82,7 +82,18 @@ def is_model_supported(model_name: str, supported_model: str) -> bool: ...@@ -82,7 +82,18 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
normalized_name = normalize_model_name(model_name).lower() normalized_name = normalize_model_name(model_name).lower()
normalized_supported = normalize_model_name(supported_model).lower() normalized_supported = normalize_model_name(supported_model).lower()
return normalized_name == normalized_supported # Exact match
if normalized_name == normalized_supported:
return True
# Handle local path case: compare only the model name part (without organization)
# e.g., "qwen2.5-vl-7b-instruct" matches "qwen/qwen2.5-vl-7b-instruct"
if "/" in normalized_supported:
model_part = normalized_supported.split("/")[-1]
if normalized_name == model_part:
return True
return False
def get_qwen_image_features( def get_qwen_image_features(
......
...@@ -237,7 +237,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -237,7 +237,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
config: Config, config: Config,
prefill_client: Client = None, prefill_client: Client = None,
): ):
super().__init__(component, engine, config, None, prefill_client) super().__init__(component, engine, config, None)
# Initialize processors # Initialize processors
self.embeddings_processor = EmbeddingsProcessor() self.embeddings_processor = EmbeddingsProcessor()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment