Unverified Commit 2da403e3 authored by Alec's avatar Alec Committed by GitHub
Browse files

refactor: add explicit non-leader node handling in vLLM (#5597)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
parent 50f1e0e1
...@@ -148,7 +148,7 @@ jobs: ...@@ -148,7 +148,7 @@ jobs:
- { major_minor: '12.9', major: '12' } - { major_minor: '12.9', major: '12' }
name: vllm-build-test (cuda${{ matrix.cuda_version.major_minor}}, ${{ matrix.platform.arch }}) name: vllm-build-test (cuda${{ matrix.cuda_version.major_minor}}, ${{ matrix.platform.arch }})
runs-on: ${{ matrix.platform.runner }} runs-on: ${{ matrix.platform.runner }}
timeout-minutes: 90 timeout-minutes: 240
env: env:
FRAMEWORK: vllm FRAMEWORK: vllm
steps: &runtime-container-build-push-test steps: &runtime-container-build-push-test
......
...@@ -48,6 +48,19 @@ configure_dynamo_logging() ...@@ -48,6 +48,19 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def _handle_non_leader_node(dp_rank: int) -> None:
"""
Handle non-leader node (data_parallel_rank >= 1) in multi-node deployments.
Non-leader nodes run vLLM workers but don't serve Dynamo endpoints.
"""
logger.info(
f"Non-leader node detected (data_parallel_rank={dp_rank}). "
"Skipping endpoint serving."
)
# Wait indefinitely - process terminated via signal handlers
await asyncio.Event().wait()
async def graceful_shutdown(runtime): async def graceful_shutdown(runtime):
""" """
Shutdown dynamo distributed runtime. Shutdown dynamo distributed runtime.
...@@ -452,20 +465,22 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -452,20 +465,22 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime.register_engine_route("wake", handler.wake) runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake") logger.info("Registered engine routes: /engine/sleep, /engine/wake")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
model_input = ( await register_vllm_model(
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens model_input,
) ModelType.Prefill,
await register_vllm_model( generate_endpoint,
model_input, config,
ModelType.Prefill, engine_client,
generate_endpoint, vllm_config,
config, migration_limit=0, # Prefill doesn't support migration
engine_client, )
vllm_config,
migration_limit=0, # Prefill doesn't support migration
)
health_check_payload = VllmPrefillHealthCheckPayload( health_check_payload = VllmPrefillHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer engine_client, use_text_input=config.use_vllm_tokenizer
...@@ -575,34 +590,34 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -575,34 +590,34 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime.register_engine_route("wake", handler.wake) runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake") logger.info("Registered engine routes: /engine/sleep, /engine/wake")
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register # Handle non-leader nodes - don't serve endpoints
# Parse endpoint types from --dyn-endpoint-types flag if config.engine_args.data_parallel_rank:
model_type = parse_endpoint_types(config.dyn_endpoint_types) await _handle_non_leader_node(config.engine_args.data_parallel_rank)
logger.info( return
f"Registering model with endpoint types: {config.dyn_endpoint_types}"
)
model_input = ( # Parse endpoint types from --dyn-endpoint-types flag
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens model_type = parse_endpoint_types(config.dyn_endpoint_types)
) logger.info(f"Registering model with endpoint types: {config.dyn_endpoint_types}")
# Warn if custom template provided but chat endpoint not enabled model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
logger.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
)
await register_vllm_model( # Warn if custom template provided but chat endpoint not enabled
model_input, if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
model_type, logger.warning(
generate_endpoint, "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
config, "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
engine_client,
vllm_config,
migration_limit=config.migration_limit,
) )
await register_vllm_model(
model_input,
model_type,
generate_endpoint,
config,
engine_client,
vllm_config,
migration_limit=config.migration_limit,
)
health_check_payload = VllmHealthCheckPayload( health_check_payload = VllmHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict() ).to_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment