Unverified Commit 704c1dad authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

fix: fix vllm graceful shutdown (#5818)

parent 284f772b
...@@ -25,7 +25,12 @@ class VllmEngineMonitor: ...@@ -25,7 +25,12 @@ class VllmEngineMonitor:
Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead. Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.
""" """
def __init__(self, runtime: DistributedRuntime, engine_client: AsyncLLM): def __init__(
self,
runtime: DistributedRuntime,
engine_client: AsyncLLM,
shutdown_event: asyncio.Event = None,
):
if not isinstance(runtime, DistributedRuntime): if not isinstance(runtime, DistributedRuntime):
raise ValueError( raise ValueError(
f"{self.__class__.__name__} requires an instance of DistributedRuntime." f"{self.__class__.__name__} requires an instance of DistributedRuntime."
...@@ -37,6 +42,7 @@ class VllmEngineMonitor: ...@@ -37,6 +42,7 @@ class VllmEngineMonitor:
self.runtime = runtime self.runtime = runtime
self.engine_client = engine_client self.engine_client = engine_client
self.shutdown_event = shutdown_event
self._monitor_task = asyncio.create_task(self._check_engine_health()) self._monitor_task = asyncio.create_task(self._check_engine_health())
logger.info( logger.info(
...@@ -66,10 +72,41 @@ class VllmEngineMonitor: ...@@ -66,10 +72,41 @@ class VllmEngineMonitor:
signal.alarm(0) signal.alarm(0)
async def _check_engine_health(self): async def _check_engine_health(self):
"""
Continuously check engine health until:
1. Engine dies (EngineDeadError) - initiate shutdown
2. Shutdown event is triggered - stop monitoring gracefully
3. Task is cancelled - cleanup
"""
while True: while True:
try: try:
# Check if shutdown event was triggered - stop monitoring
if self.shutdown_event and self.shutdown_event.is_set():
logger.info(
f"{self.__class__.__name__}: Shutdown event detected, stopping engine health monitoring."
)
break
await self.engine_client.check_health() await self.engine_client.check_health()
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
# Sleep with shutdown event awareness for faster response
if self.shutdown_event:
try:
await asyncio.wait_for(
self.shutdown_event.wait(), timeout=HEALTH_CHECK_INTERVAL
)
# Shutdown event was set during sleep
logger.info(
f"{self.__class__.__name__}: Shutdown event detected, stopping engine health monitoring."
)
break
except asyncio.TimeoutError:
# Normal timeout, continue monitoring
pass
else:
# No shutdown event, just sleep normally
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"Traceback: {traceback.format_exc()}") logger.error(f"Traceback: {traceback.format_exc()}")
logger.error(f"vLLM AsyncLLM health check failed: {e}") logger.error(f"vLLM AsyncLLM health check failed: {e}")
...@@ -78,4 +115,5 @@ class VllmEngineMonitor: ...@@ -78,4 +115,5 @@ class VllmEngineMonitor:
self.runtime.shutdown() self.runtime.shutdown()
os._exit(1) os._exit(1)
except asyncio.CancelledError: except asyncio.CancelledError:
pass logger.debug(f"{self.__class__.__name__}: Health check task cancelled.")
break
...@@ -243,6 +243,7 @@ class BaseWorkerHandler(ABC): ...@@ -243,6 +243,7 @@ class BaseWorkerHandler(ABC):
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
): ):
self.runtime = runtime self.runtime = runtime
self.component = component self.component = component
...@@ -251,7 +252,7 @@ class BaseWorkerHandler(ABC): ...@@ -251,7 +252,7 @@ class BaseWorkerHandler(ABC):
self.kv_publishers: list[ZmqKvEventPublisher] | None = None self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
self.config = config self.config = config
self.engine_monitor = VllmEngineMonitor(runtime, engine) self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.image_loader = ImageLoader() self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len self.model_max_len = model_max_len
...@@ -272,6 +273,9 @@ class BaseWorkerHandler(ABC): ...@@ -272,6 +273,9 @@ class BaseWorkerHandler(ABC):
tokenizer = engine.tokenizer tokenizer = engine.tokenizer
self.input_param_manager = InputParamManager(tokenizer) self.input_param_manager = InputParamManager(tokenizer)
# Store shutdown event for graceful shutdown monitoring
self.shutdown_event = shutdown_event
async def sleep(self, body: dict) -> dict: async def sleep(self, body: dict) -> dict:
"""Sleep the engine to release GPU memory and unregister from discovery. """Sleep the engine to release GPU memory and unregister from discovery.
...@@ -339,14 +343,44 @@ class BaseWorkerHandler(ABC): ...@@ -339,14 +343,44 @@ class BaseWorkerHandler(ABC):
raise NotImplementedError raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill): async def _monitor_abort(self, context, request_id, is_prefill):
"""Background task that monitors for context cancellation and aborts the request.""" """
Background task that monitors for context cancellation and shutdown.
Aborts the request if either occurs. Raises GeneratorExit if shutdown was triggered.
"""
try: try:
await context.async_killed_or_stopped() # Build list of futures/tasks to wait for
# If we reach here, the context was stopped or killed wait_for = [context.async_killed_or_stopped()]
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
# Wait for whichever happens first
done, pending = await asyncio.wait(
wait_for,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Abort the request
await self.engine_client.abort(request_id) await self.engine_client.abort(request_id)
logger.debug( logger.debug(
f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}" f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}"
) )
# Check which event triggered and raise GeneratorExit if shutdown
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError: except asyncio.CancelledError:
# Task was cancelled, normal cleanup if not aborted # Task was cancelled, normal cleanup if not aborted
pass pass
...@@ -355,18 +389,24 @@ class BaseWorkerHandler(ABC): ...@@ -355,18 +389,24 @@ class BaseWorkerHandler(ABC):
@asynccontextmanager @asynccontextmanager
async def _abort_monitor(self, context, request_id, is_prefill=False): async def _abort_monitor(self, context, request_id, is_prefill=False):
"""Context manager that creates and automatically cleans up an abort monitoring task.""" """
Context manager that creates and automatically cleans up an abort monitoring task.
If shutdown event was triggered, raises GeneratorExit on exit.
"""
task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill)) task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill))
try: try:
yield task yield task
finally: finally:
# Cancel the abort monitoring task when exiting the context # Clean up the abort monitoring task
if not task.done(): if not task.done():
task.cancel() task.cancel()
try: try:
await task await task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
else:
# If the task completed, check if it raised GeneratorExit
task.result()
async def clear_kv_blocks(self, request=None): async def clear_kv_blocks(self, request=None):
try: try:
...@@ -389,6 +429,20 @@ class BaseWorkerHandler(ABC): ...@@ -389,6 +429,20 @@ class BaseWorkerHandler(ABC):
self._lora_load_locks[lora_name] = lock self._lora_load_locks[lora_name] = lock
return lock return lock
def _normalize_finish_reason(self, finish_reason: str) -> str:
"""
Normalize vLLM finish reasons to Dynamo-compatible values.
vLLM may return finish reasons that aren't recognized by Dynamo's Rust layer.
This method maps them to compatible values.
[TODO]: Remove this method and add the right code in the Rust layer.
"""
# Map vLLM's "abort" to Dynamo's "cancelled"
if finish_reason.startswith("abort"):
logging.debug(f"Normalizing finish reason: {finish_reason} to cancelled")
return "cancelled"
return finish_reason
async def load_lora(self, request=None): async def load_lora(self, request=None):
""" """
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine. Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
...@@ -1112,63 +1166,57 @@ class BaseWorkerHandler(ABC): ...@@ -1112,63 +1166,57 @@ class BaseWorkerHandler(ABC):
) )
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
try: async for res in gen:
async for res in gen: # res is vllm's RequestOutput
# res is vllm's RequestOutput
if not res.outputs: if not res.outputs:
self._log_with_lora_context( self._log_with_lora_context(
"Request {request_id}{lora_info} returned no outputs", "Request {request_id}{lora_info} returned no outputs",
request_id, request_id,
lora_request, lora_request,
) )
# Use string format "error: message" for consistency with vLLM's string-based finish_reason # Use string format "error: message" for consistency with vLLM's string-based finish_reason
# Rust will parse this into FinishReason::Error(message) # Rust will parse this into FinishReason::Error(message)
yield { yield {
"finish_reason": "error: No outputs from vLLM engine", "finish_reason": "error: No outputs from vLLM engine",
"token_ids": [], "token_ids": [],
} }
break break
output = res.outputs[0] output = res.outputs[0]
next_total_toks = len(output.token_ids) next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs for new tokens if available # Extract logprobs for new tokens if available
log_probs, top_logprobs = self._extract_logprobs( log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
) )
if log_probs is not None: out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
out["log_probs"] = log_probs request_output=res,
if top_logprobs is not None: embedding_sequence_length=embedding_sequence_length,
out["top_logprobs"] = top_logprobs )
# Log completion with LoRA info (debug level to avoid log spam)
if output.finish_reason: self._log_with_lora_context(
out["finish_reason"] = output.finish_reason "Completed token generation for request {request_id}{lora_info}: "
out[ "{output_tokens} output tokens, finish_reason={finish_reason}",
"completion_usage" request_id,
] = BaseWorkerHandler._build_completion_usage( lora_request,
request_output=res, output_tokens=next_total_toks,
embedding_sequence_length=embedding_sequence_length, finish_reason=output.finish_reason,
) )
# Log completion with LoRA info (debug level to avoid log spam) if output.stop_reason:
self._log_with_lora_context( out["stop_reason"] = output.stop_reason
"Completed token generation for request {request_id}{lora_info}: " yield out
"{output_tokens} output tokens, finish_reason={finish_reason}", num_output_tokens_so_far = next_total_toks
request_id,
lora_request,
output_tokens=next_total_toks,
finish_reason=output.finish_reason,
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
...@@ -1189,6 +1237,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1189,6 +1237,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -1200,6 +1249,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1200,6 +1249,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint, generate_endpoint,
config, config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1361,7 +1411,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1361,7 +1411,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"role": "assistant", "role": "assistant",
"content": delta_text, "content": delta_text,
}, },
"finish_reason": output.finish_reason, "finish_reason": self._normalize_finish_reason(
output.finish_reason
),
} }
chunk = { chunk = {
...@@ -1398,6 +1450,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1398,6 +1450,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -1409,6 +1462,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1409,6 +1462,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint, generate_endpoint,
config, config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1501,39 +1555,33 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1501,39 +1555,33 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown() self.runtime.shutdown()
os._exit(1) os._exit(1)
try: async for res in gen:
async for res in gen: logger.debug(f"kv transfer params: {res.kv_transfer_params}")
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
token_ids = res.outputs[0].token_ids if res.outputs else []
token_ids = res.outputs[0].token_ids if res.outputs else []
output: Dict[str, Any] = {
output: Dict[str, Any] = { "token_ids": list(token_ids),
"token_ids": list(token_ids), "disaggregated_params": (
"disaggregated_params": ( {"kv_transfer_params": res.kv_transfer_params}
{"kv_transfer_params": res.kv_transfer_params} if res.kv_transfer_params
if res.kv_transfer_params else None
else None ),
), "completion_usage": BaseWorkerHandler._build_completion_usage(
"completion_usage": BaseWorkerHandler._build_completion_usage( request_output=res,
request_output=res, embedding_sequence_length=embedding_sequence_length,
embedding_sequence_length=embedding_sequence_length, ),
), }
}
# Log prefill completion with LoRA info # Log prefill completion with LoRA info
self._log_with_lora_context( self._log_with_lora_context(
"Prefill completed for request {request_id}{lora_info}: " "Prefill completed for request {request_id}{lora_info}: "
"generated {token_count} token(s), has_kv_params={has_kv_params}", "generated {token_count} token(s), has_kv_params={has_kv_params}",
request_id, request_id,
lora_request, lora_request,
level="info" if lora_request else "debug", level="info" if lora_request else "debug",
token_count=len(token_ids), token_count=len(token_ids),
has_kv_params=res.kv_transfer_params is not None, has_kv_params=res.kv_transfer_params is not None,
) )
yield output yield output
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
...@@ -61,7 +61,7 @@ async def _handle_non_leader_node(dp_rank: int) -> None: ...@@ -61,7 +61,7 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait() await asyncio.Event().wait()
async def graceful_shutdown(runtime): async def graceful_shutdown(runtime, shutdown_event):
""" """
Shutdown dynamo distributed runtime. Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted. The endpoints will be immediately invalidated so no new requests will be accepted.
...@@ -69,6 +69,7 @@ async def graceful_shutdown(runtime): ...@@ -69,6 +69,7 @@ async def graceful_shutdown(runtime):
For endpoints served with graceful_shutdown=False, the serving function will return immediately. For endpoints served with graceful_shutdown=False, the serving function will return immediately.
""" """
logging.info("Received shutdown signal, shutting down DistributedRuntime") logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown() runtime.shutdown()
logging.info("DistributedRuntime shutdown complete") logging.info("DistributedRuntime shutdown complete")
...@@ -79,6 +80,9 @@ async def worker(): ...@@ -79,6 +80,9 @@ async def worker():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
overwrite_args(config) overwrite_args(config)
# Create shutdown event
shutdown_event = asyncio.Event()
# Set DYN_EVENT_PLANE environment variable based on config # Set DYN_EVENT_PLANE environment variable based on config
os.environ["DYN_EVENT_PLANE"] = config.event_plane os.environ["DYN_EVENT_PLANE"] = config.event_plane
...@@ -95,7 +99,7 @@ async def worker(): ...@@ -95,7 +99,7 @@ async def worker():
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, signal_handler)
...@@ -123,29 +127,29 @@ async def worker(): ...@@ -123,29 +127,29 @@ async def worker():
# Route to appropriate initialization based on config flags # Route to appropriate initialization based on config flags
if config.vllm_native_encoder_worker: if config.vllm_native_encoder_worker:
await init_vllm_native_encoder(runtime, config) await init_vllm_native_encoder(runtime, config, shutdown_event)
logger.debug("init_vllm_native_encoder completed") logger.debug("init_vllm_native_encoder completed")
elif config.ec_processor: elif config.ec_processor:
await init_ec_processor(runtime, config) await init_ec_processor(runtime, config, shutdown_event)
logger.debug("init_ec_processor completed") logger.debug("init_ec_processor completed")
elif config.multimodal_processor: elif config.multimodal_processor:
await init_multimodal_processor(runtime, config) await init_multimodal_processor(runtime, config, shutdown_event)
logger.debug("init_multimodal_processor completed") logger.debug("init_multimodal_processor completed")
elif config.multimodal_encode_worker: elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config) await init_multimodal_encode_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_encode_worker completed") logger.debug("init_multimodal_encode_worker completed")
elif ( elif (
config.multimodal_worker config.multimodal_worker
or config.multimodal_decode_worker or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker or config.multimodal_encode_prefill_worker
): ):
await init_multimodal_worker(runtime, config) await init_multimodal_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_worker completed") logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker: elif config.is_prefill_worker:
await init_prefill(runtime, config) await init_prefill(runtime, config, shutdown_event)
logger.debug("init_prefill completed") logger.debug("init_prefill completed")
else: else:
await init(runtime, config) await init(runtime, config, shutdown_event)
logger.debug("init completed") logger.debug("init completed")
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
...@@ -415,7 +419,9 @@ async def register_vllm_model( ...@@ -415,7 +419,9 @@ async def register_vllm_model(
) )
async def init_prefill(runtime: DistributedRuntime, config: Config): async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Instantiate and serve Instantiate and serve
""" """
...@@ -441,6 +447,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -441,6 +447,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -527,7 +534,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -527,7 +534,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init(runtime: DistributedRuntime, config: Config): async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Instantiate and serve Instantiate and serve
""" """
...@@ -566,6 +575,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -566,6 +575,7 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -699,7 +709,9 @@ def get_engine_cache_info(engine: AsyncLLM): ...@@ -699,7 +709,9 @@ def get_engine_cache_info(engine: AsyncLLM):
raise raise
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config): async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
...@@ -754,7 +766,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config) ...@@ -754,7 +766,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
handler.cleanup() handler.cleanup()
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config): async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal encode worker component""" """Initialize multimodal encode worker component"""
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
...@@ -792,7 +806,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con ...@@ -792,7 +806,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
handler.cleanup() handler.cleanup()
async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config): async def init_vllm_native_encoder(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Initialize vLLM-native encoder worker component (ECConnector mode). Initialize vLLM-native encoder worker component (ECConnector mode).
In this mode, vLLM handles encoder execution, caching, and storage automatically. In this mode, vLLM handles encoder execution, caching, and storage automatically.
...@@ -853,7 +869,9 @@ async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config): ...@@ -853,7 +869,9 @@ async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init_ec_processor(runtime: DistributedRuntime, config: Config): async def init_ec_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Initialize ECConnector processor component. Initialize ECConnector processor component.
...@@ -923,7 +941,9 @@ async def init_ec_processor(runtime: DistributedRuntime, config: Config): ...@@ -923,7 +941,9 @@ async def init_ec_processor(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Initialize multimodal worker component. Initialize multimodal worker component.
...@@ -983,11 +1003,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -983,11 +1003,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
# Choose handler based on worker type # Choose handler based on worker type
if config.multimodal_decode_worker: if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler( handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config runtime, component, engine_client, config, shutdown_event
) )
else: else:
handler = MultimodalPDWorkerHandler( handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, decode_worker_client runtime,
component,
engine_client,
config,
decode_worker_client,
shutdown_event,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
......
...@@ -37,6 +37,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -37,6 +37,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
component, component,
engine_client, engine_client,
config, config,
shutdown_event=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
default_sampling_params = ( default_sampling_params = (
...@@ -50,6 +51,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -50,6 +51,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
shutdown_event=shutdown_event,
) )
self.config = config self.config = config
...@@ -117,6 +119,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -117,6 +119,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client: AsyncLLM, engine_client: AsyncLLM,
config, config,
decode_worker_client: Client = None, decode_worker_client: Client = None,
shutdown_event=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
default_sampling_params = ( default_sampling_params = (
...@@ -130,6 +133,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -130,6 +133,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
shutdown_event=shutdown_event,
) )
self.config = config self.config = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment