Unverified Commit 8c2072cf authored by Alec's avatar Alec Committed by GitHub
Browse files

fix: [trtllm] add wait_for_instance before register_llm (#2683)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
parent 63f5bbc0
......@@ -239,17 +239,6 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
modelType,
endpoint,
config.model_path,
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
......@@ -262,6 +251,23 @@ async def init(runtime: DistributedRuntime, config: Config):
multimodal_processor=multimodal_processor,
)
if next_client:
logging.info(
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
)
await next_client.wait_for_instances()
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
modelType,
endpoint,
config.model_path,
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
......
......@@ -23,7 +23,7 @@ class EngineConfig:
endpoints: List[str]
response_handlers: List[Callable[[Any], str]]
model: str
timeout: int = 120
timeout: int = 600
delayed_start: int = 0
......
......@@ -22,8 +22,6 @@ logger = logging.getLogger(__name__)
class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios"""
timeout: int = 60
class TRTLLMProcess(EngineProcess):
"""Simple process manager for trtllm shell scripts"""
......@@ -71,9 +69,7 @@ trtllm_configs = {
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=0,
timeout=360,
model="Qwen/Qwen3-0.6B",
),
"disaggregated": TRTLLMConfig(
name="disaggregated",
......@@ -85,9 +81,7 @@ trtllm_configs = {
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=0,
timeout=360,
model="Qwen/Qwen3-0.6B",
),
# TODO: These are sanity tests that the kv router examples launch
# and inference without error, but do not do detailed checks on the
......@@ -102,9 +96,7 @@ trtllm_configs = {
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=0,
timeout=360,
model="Qwen/Qwen3-0.6B",
),
"disaggregated_router": TRTLLMConfig(
name="disaggregated_router",
......@@ -116,9 +108,7 @@ trtllm_configs = {
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=0,
timeout=360,
model="Qwen/Qwen3-0.6B",
),
}
......
......@@ -133,8 +133,6 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
),
"agg-router": VLLMConfig(
name="agg-router",
......@@ -147,8 +145,6 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
),
"disaggregated": VLLMConfig(
name="disaggregated",
......@@ -161,8 +157,6 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=0,
timeout=360,
),
"deepep": VLLMConfig(
name="deepep",
......@@ -179,7 +173,6 @@ vllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-V2-Lite",
delayed_start=0,
args=[
"--model",
"deepseek-ai/DeepSeek-V2-Lite",
......@@ -190,7 +183,7 @@ vllm_configs = {
"--gpus-per-node",
"2",
],
timeout=560,
timeout=700,
),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
......@@ -202,9 +195,7 @@ vllm_configs = {
chat_completions_response_handler,
],
model="llava-hf/llava-1.5-7b-hf",
delayed_start=0,
args=["--model", "llava-hf/llava-1.5-7b-hf"],
timeout=360,
),
"multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen",
......
......@@ -17,6 +17,7 @@ import json
import logging
import os
import shutil
import signal
import socket
import subprocess
import time
......@@ -82,6 +83,10 @@ class ManagedProcess:
straggler_commands: List[str] = field(default_factory=list)
log_dir: str = os.getcwd()
# Ensure attributes exist even if startup fails early
proc: Optional[subprocess.Popen] = None
_pgid: Optional[int] = None
_logger = logging.getLogger()
_command_name = None
_log_path = None
......@@ -107,20 +112,30 @@ class ManagedProcess:
return self
except Exception as e:
self.__exit__(None, None, None)
raise e
except Exception:
try:
self.__exit__(None, None, None)
except Exception as cleanup_err:
self._logger.warning(
"Error during cleanup in __enter__: %s", cleanup_err
)
raise
def __exit__(self, exc_type, exc_val, exc_tb):
self._terminate_process_group()
process_list = [self.proc, self._tee_proc, self._sed_proc]
for process in process_list:
if process:
if process.stdout:
process.stdout.close()
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
try:
if process.stdout:
process.stdout.close()
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir:
self._remove_directory(self.data_dir)
......@@ -169,6 +184,12 @@ class ManagedProcess:
stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
)
# Capture the child's process group id for robust cleanup even if parent shell exits
try:
self._pgid = os.getpgid(self.proc.pid)
except Exception as e:
self._logger.warning("Could not get process group id: %s", e)
self._pgid = None
self._sed_proc = subprocess.Popen(
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
stdin=self.proc.stdout,
......@@ -190,6 +211,12 @@ class ManagedProcess:
stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
)
# Capture the child's process group id for robust cleanup even if parent shell exits
try:
self._pgid = os.getpgid(self.proc.pid)
except Exception as e:
self._logger.warning("Could not get process group id: %s", e)
self._pgid = None
self._sed_proc = subprocess.Popen(
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
......@@ -198,6 +225,38 @@ class ManagedProcess:
)
self._tee_proc = None
def _terminate_process_group(self, timeout: float = 5.0):
"""Terminate the entire process group/session started for the child.
This catches cases where the launcher shell exits and its children are reparented,
leaving no parent PID to traverse, but they remain in the same process group.
"""
if self._pgid is None:
return
try:
self._logger.info("Terminating process group: %s", self._pgid)
os.killpg(self._pgid, signal.SIGTERM)
except ProcessLookupError:
return
except Exception as e:
self._logger.warning(
"Error sending SIGTERM to process group %s: %s", self._pgid, e
)
return
# Give processes a brief moment to exit gracefully
time.sleep(timeout)
# Force kill if anything remains
try:
os.killpg(self._pgid, signal.SIGKILL)
except ProcessLookupError:
pass
except Exception as e:
self._logger.warning(
"Error sending SIGKILL to process group %s: %s", self._pgid, e
)
def _remove_directory(self, path: str) -> None:
"""Remove a directory."""
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment