Unverified commit 8c2072cf, authored by Alec, committed by GitHub
Browse files

fix: [trtllm] add wait_for_instance before register_llm (#2683)


Signed-off-by: alec-flowers <aflowers@nvidia.com>
parent 63f5bbc0
...@@ -239,17 +239,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -239,17 +239,6 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.tool_call_parser
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
modelType,
endpoint,
config.model_path,
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
)
# publisher will be set later if publishing is enabled. # publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig( handler_config = RequestHandlerConfig(
component=component, component=component,
...@@ -262,6 +251,23 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -262,6 +251,23 @@ async def init(runtime: DistributedRuntime, config: Config):
multimodal_processor=multimodal_processor, multimodal_processor=multimodal_processor,
) )
if next_client:
logging.info(
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
)
await next_client.wait_for_instances()
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
modelType,
endpoint,
config.model_path,
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
if config.publish_events_and_metrics and is_first_worker(config): if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to # Initialize and pass in the publisher to the request handler to
# publish events and metrics. # publish events and metrics.
......
...@@ -23,7 +23,7 @@ class EngineConfig: ...@@ -23,7 +23,7 @@ class EngineConfig:
endpoints: List[str] endpoints: List[str]
response_handlers: List[Callable[[Any], str]] response_handlers: List[Callable[[Any], str]]
model: str model: str
timeout: int = 120 timeout: int = 600
delayed_start: int = 0 delayed_start: int = 0
......
...@@ -22,8 +22,6 @@ logger = logging.getLogger(__name__) ...@@ -22,8 +22,6 @@ logger = logging.getLogger(__name__)
class TRTLLMConfig(EngineConfig): class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios""" """Configuration for trtllm test scenarios"""
timeout: int = 60
class TRTLLMProcess(EngineProcess): class TRTLLMProcess(EngineProcess):
"""Simple process manager for trtllm shell scripts""" """Simple process manager for trtllm shell scripts"""
...@@ -71,9 +69,7 @@ trtllm_configs = { ...@@ -71,9 +69,7 @@ trtllm_configs = {
chat_completions_response_handler, chat_completions_response_handler,
completions_response_handler, completions_response_handler,
], ],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
"disaggregated": TRTLLMConfig( "disaggregated": TRTLLMConfig(
name="disaggregated", name="disaggregated",
...@@ -85,9 +81,7 @@ trtllm_configs = { ...@@ -85,9 +81,7 @@ trtllm_configs = {
chat_completions_response_handler, chat_completions_response_handler,
completions_response_handler, completions_response_handler,
], ],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
# TODO: These are sanity tests that the kv router examples launch # TODO: These are sanity tests that the kv router examples launch
# and inference without error, but do not do detailed checks on the # and inference without error, but do not do detailed checks on the
...@@ -102,9 +96,7 @@ trtllm_configs = { ...@@ -102,9 +96,7 @@ trtllm_configs = {
chat_completions_response_handler, chat_completions_response_handler,
completions_response_handler, completions_response_handler,
], ],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
"disaggregated_router": TRTLLMConfig( "disaggregated_router": TRTLLMConfig(
name="disaggregated_router", name="disaggregated_router",
...@@ -116,9 +108,7 @@ trtllm_configs = { ...@@ -116,9 +108,7 @@ trtllm_configs = {
chat_completions_response_handler, chat_completions_response_handler,
completions_response_handler, completions_response_handler,
], ],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
} }
......
...@@ -133,8 +133,6 @@ vllm_configs = { ...@@ -133,8 +133,6 @@ vllm_configs = {
completions_response_handler, completions_response_handler,
], ],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
"agg-router": VLLMConfig( "agg-router": VLLMConfig(
name="agg-router", name="agg-router",
...@@ -147,8 +145,6 @@ vllm_configs = { ...@@ -147,8 +145,6 @@ vllm_configs = {
completions_response_handler, completions_response_handler,
], ],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
"disaggregated": VLLMConfig( "disaggregated": VLLMConfig(
name="disaggregated", name="disaggregated",
...@@ -161,8 +157,6 @@ vllm_configs = { ...@@ -161,8 +157,6 @@ vllm_configs = {
completions_response_handler, completions_response_handler,
], ],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
), ),
"deepep": VLLMConfig( "deepep": VLLMConfig(
name="deepep", name="deepep",
...@@ -179,7 +173,6 @@ vllm_configs = { ...@@ -179,7 +173,6 @@ vllm_configs = {
completions_response_handler, completions_response_handler,
], ],
model="deepseek-ai/DeepSeek-V2-Lite", model="deepseek-ai/DeepSeek-V2-Lite",
delayed_start=0,
args=[ args=[
"--model", "--model",
"deepseek-ai/DeepSeek-V2-Lite", "deepseek-ai/DeepSeek-V2-Lite",
...@@ -190,7 +183,7 @@ vllm_configs = { ...@@ -190,7 +183,7 @@ vllm_configs = {
"--gpus-per-node", "--gpus-per-node",
"2", "2",
], ],
timeout=560, timeout=700,
), ),
"multimodal_agg_llava": VLLMConfig( "multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava", name="multimodal_agg_llava",
...@@ -202,9 +195,7 @@ vllm_configs = { ...@@ -202,9 +195,7 @@ vllm_configs = {
chat_completions_response_handler, chat_completions_response_handler,
], ],
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
delayed_start=0,
args=["--model", "llava-hf/llava-1.5-7b-hf"], args=["--model", "llava-hf/llava-1.5-7b-hf"],
timeout=360,
), ),
"multimodal_agg_qwen": VLLMConfig( "multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen", name="multimodal_agg_qwen",
......
...@@ -17,6 +17,7 @@ import json ...@@ -17,6 +17,7 @@ import json
import logging import logging
import os import os
import shutil import shutil
import signal
import socket import socket
import subprocess import subprocess
import time import time
...@@ -82,6 +83,10 @@ class ManagedProcess: ...@@ -82,6 +83,10 @@ class ManagedProcess:
straggler_commands: List[str] = field(default_factory=list) straggler_commands: List[str] = field(default_factory=list)
log_dir: str = os.getcwd() log_dir: str = os.getcwd()
# Ensure attributes exist even if startup fails early
proc: Optional[subprocess.Popen] = None
_pgid: Optional[int] = None
_logger = logging.getLogger() _logger = logging.getLogger()
_command_name = None _command_name = None
_log_path = None _log_path = None
...@@ -107,20 +112,30 @@ class ManagedProcess: ...@@ -107,20 +112,30 @@ class ManagedProcess:
return self return self
except Exception as e: except Exception:
self.__exit__(None, None, None) try:
raise e self.__exit__(None, None, None)
except Exception as cleanup_err:
self._logger.warning(
"Error during cleanup in __enter__: %s", cleanup_err
)
raise
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self._terminate_process_group()
process_list = [self.proc, self._tee_proc, self._sed_proc] process_list = [self.proc, self._tee_proc, self._sed_proc]
for process in process_list: for process in process_list:
if process: if process:
if process.stdout: try:
process.stdout.close() if process.stdout:
if process.stdin: process.stdout.close()
process.stdin.close() if process.stdin:
terminate_process_tree(process.pid, self._logger) process.stdin.close()
process.wait() terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir: if self.data_dir:
self._remove_directory(self.data_dir) self._remove_directory(self.data_dir)
...@@ -169,6 +184,12 @@ class ManagedProcess: ...@@ -169,6 +184,12 @@ class ManagedProcess:
stderr=stderr, stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
) )
# Capture the child's process group id for robust cleanup even if parent shell exits
try:
self._pgid = os.getpgid(self.proc.pid)
except Exception as e:
self._logger.warning("Could not get process group id: %s", e)
self._pgid = None
self._sed_proc = subprocess.Popen( self._sed_proc = subprocess.Popen(
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"], ["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
stdin=self.proc.stdout, stdin=self.proc.stdout,
...@@ -190,6 +211,12 @@ class ManagedProcess: ...@@ -190,6 +211,12 @@ class ManagedProcess:
stderr=stderr, stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
) )
# Capture the child's process group id for robust cleanup even if parent shell exits
try:
self._pgid = os.getpgid(self.proc.pid)
except Exception as e:
self._logger.warning("Could not get process group id: %s", e)
self._pgid = None
self._sed_proc = subprocess.Popen( self._sed_proc = subprocess.Popen(
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"], ["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
...@@ -198,6 +225,38 @@ class ManagedProcess: ...@@ -198,6 +225,38 @@ class ManagedProcess:
) )
self._tee_proc = None self._tee_proc = None
def _terminate_process_group(self, timeout: float = 5.0):
"""Terminate the entire process group/session started for the child.
This catches cases where the launcher shell exits and its children are reparented,
leaving no parent PID to traverse, but they remain in the same process group.
"""
if self._pgid is None:
return
try:
self._logger.info("Terminating process group: %s", self._pgid)
os.killpg(self._pgid, signal.SIGTERM)
except ProcessLookupError:
return
except Exception as e:
self._logger.warning(
"Error sending SIGTERM to process group %s: %s", self._pgid, e
)
return
# Give processes a brief moment to exit gracefully
time.sleep(timeout)
# Force kill if anything remains
try:
os.killpg(self._pgid, signal.SIGKILL)
except ProcessLookupError:
pass
except Exception as e:
self._logger.warning(
"Error sending SIGKILL to process group %s: %s", self._pgid, e
)
def _remove_directory(self, path: str) -> None: def _remove_directory(self, path: str) -> None:
"""Remove a directory.""" """Remove a directory."""
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment