"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "2b5655fd43fa45f002a532d0d8239c4b4d99ac71"
Unverified Commit 608405e0 authored by Keiven C's avatar Keiven C Committed by GitHub
Browse files

feat: harden ManagedProcess teardown, add xdist-safe tests (#6670)


Signed-off-by: default avatarKeiven Chang <keivenchang@users.noreply.github.com>
parent 7e07495f
...@@ -138,7 +138,7 @@ class ManagedProcess: ...@@ -138,7 +138,7 @@ class ManagedProcess:
# ✅ SAFE for parallel tests (only kills what we launch): # ✅ SAFE for parallel tests (only kills what we launch):
ManagedProcess( ManagedProcess(
command=["vllm", "serve", "--port", "8000"], command=["vllm", "serve", "--port", str(dynamically_allocated_port)],
terminate_all_matching_process_names=False, # Don't kill other processes terminate_all_matching_process_names=False, # Don't kill other processes
) )
...@@ -237,7 +237,16 @@ class ManagedProcess: ...@@ -237,7 +237,16 @@ class ManagedProcess:
raise raise
def _cleanup_stragglers(self): def _cleanup_stragglers(self):
"""Clean up straggler processes - called during exit and signal handling""" """Clean up straggler processes - called during exit and signal handling.
WARNING: NOT pytest-xdist safe! This does a system-wide sweep matching by
process name (self.stragglers) and command-line substring (self.straggler_commands),
similar to _terminate_all_matching_process_names. Skipped when
terminate_all_matching_process_names=False (i.e. xdist-safe mode) to avoid
killing other workers' processes.
"""
if not self.terminate_all_matching_process_names:
return
try: try:
if self.stragglers or self.straggler_commands: if self.stragglers or self.straggler_commands:
self._logger.info( self._logger.info(
...@@ -253,7 +262,9 @@ class ManagedProcess: ...@@ -253,7 +262,9 @@ class ManagedProcess:
self._logger.info( self._logger.info(
"Terminating Straggler %s %s", process_name, ps_process.pid "Terminating Straggler %s %s", process_name, ps_process.pid
) )
terminate_process_tree(ps_process.pid, self._logger) terminate_process_tree(
ps_process.pid, self._logger, immediate_kill=True
)
# Check command line arguments # Check command line arguments
cmdline = ps_process.cmdline() cmdline = ps_process.cmdline()
...@@ -266,7 +277,9 @@ class ManagedProcess: ...@@ -266,7 +277,9 @@ class ManagedProcess:
ps_process.pid, ps_process.pid,
straggler_cmd, straggler_cmd,
) )
terminate_process_tree(ps_process.pid, self._logger) terminate_process_tree(
ps_process.pid, self._logger, immediate_kill=True
)
break # Avoid terminating the same process multiple times break # Avoid terminating the same process multiple times
except ( except (
psutil.NoSuchProcess, psutil.NoSuchProcess,
...@@ -285,65 +298,20 @@ class ManagedProcess: ...@@ -285,65 +298,20 @@ class ManagedProcess:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Cleanup: Terminate launched processes. """Cleanup: Terminate launched processes.
Termination Strategy (Graceful → Escalate to Force): Termination Strategy:
===================================================== =====================
1. Snapshot child processes (before killing parent, so we can find them) 1. _stop_started_processes: calls _terminate_process_group (SIGTERM →
2. Send SIGTERM to process group immediately (no delay before SIGTERM) poll → SIGKILL all snapshotted pgids), closes pipes, waits/escalates
3. Wait up to 2s (poll every 0.1s) for processes to exit per-process.
4. If still alive after 2s: Send SIGKILL (force kill) 2. Clean up data directory if configured.
5. Terminate individual processes (self.proc, tee, sed): 3. Run name-based straggler cleanup (only when not in xdist-safe mode).
- Send SIGTERM immediately (no delay)
- Wait up to 10s for exit
- If still alive after 10s: Send SIGKILL (force kill)
6. Kill any orphaned child processes that escaped the process group
(e.g. TRT-LLM engine workers started via MPI in separate PGIDs)
7. Clean up straggler processes (if configured)
Signal Details:
- SIGTERM (15): Graceful - allows cleanup handlers to run
- SIGKILL (9): Force kill - immediate, cannot be caught
Timeout Parameter:
- Controls how long to WAIT AFTER SIGTERM before escalating to SIGKILL
- NOT a delay before sending SIGTERM (SIGTERM is sent immediately)
This ALWAYS runs regardless of terminate_all_matching_process_names setting.
""" """
try: try:
# Snapshot all child processes BEFORE killing the parent.
# Some frameworks (e.g. TRT-LLM via MPI) spawn children in separate
# process groups. After the parent dies these children become orphans
# reparented to init, so psutil can no longer find them via the parent
# PID. We must capture them now while the parent is still alive.
orphan_candidates = []
if self.proc:
try:
parent = psutil.Process(self.proc.pid)
orphan_candidates = parent.children(recursive=True)
except psutil.NoSuchProcess:
pass
self._stop_started_processes() self._stop_started_processes()
# Kill any child processes that survived the process group kill.
# This catches children in different PGIDs (e.g. MPI workers, engine
# cores) that os.killpg() could not reach.
for child in orphan_candidates:
try:
if child.is_running():
self._logger.info(
"Killing orphaned child process: PID %s name=%s",
child.pid,
child.name(),
)
terminate_process_tree(child.pid, self._logger)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
if self.data_dir: if self.data_dir:
self._remove_directory(self.data_dir) self._remove_directory(self.data_dir)
finally: finally:
# Always run straggler cleanup, even if interrupted
self._cleanup_stragglers() self._cleanup_stragglers()
def _stop_started_processes(self, wait_timeout: float = 10.0): def _stop_started_processes(self, wait_timeout: float = 10.0):
...@@ -524,64 +492,109 @@ class ManagedProcess: ...@@ -524,64 +492,109 @@ class ManagedProcess:
Kill Sequence: Kill Sequence:
============== ==============
1. Send SIGTERM to entire process group IMMEDIATELY (no delay) 1. Snapshot all child process groups (the main pgid plus any distinct
2. Wait up to `timeout` seconds (default 8s), polling every 0.1s child pgids) before sending any signals.
3. If still alive after timeout: Send SIGKILL (force kill, immediate) 2. Send SIGTERM to ALL snapshotted process groups (no delay).
os.killpg delivers the signal atomically to every group member.
Timeout Parameter: 3. Wait up to `timeout` seconds (default 8s), polling every 0.1s.
- Controls how long to WAIT AFTER SIGTERM before escalating to SIGKILL 8s usually gives engines (e.g. vLLM EngineCore) enough time to flush
- NOT a delay before sending SIGTERM (SIGTERM is sent immediately) in-flight work and release resources before we escalate.
- 8s gives engines (TRT-LLM, vLLM, etc.) enough time to gracefully 4. SIGKILL ALL snapshotted process groups (those that ignored SIGTERM
shut down MPI workers, release GPU memory, and drain pending requests or timed out).
- Polling at 0.1s intervals means fast exits are not penalized
Limitation: if a child calls setsid()/setpgid() AFTER our snapshot,
Process groups catch cases where the launcher shell exits and its its new group won't be in our set. This is rare -- Python
children are reparented, leaving no parent PID to traverse, but they multiprocessing and vLLM engine core inherit the parent's group.
remain in the same process group.
Post-SIGKILL resource cleanup notes:
- GPU memory: vLLM Engine releases VRAM on process death (driver-level
reclaim). No manual GPU cleanup needed after SIGKILL.
- Distributed state: NATS subscriptions, etcd leases, and KV cache
index entries are NOT cleaned up by SIGKILL. These persist until
TTL expiry or explicit purge.
- Shared memory: POSIX shm segments (/dev/shm/) survive process
death and must be unlinked separately if used.
- Network ports: TCP sockets enter TIME_WAIT (~60s); rebinding the
same port immediately may fail. Prefer allocate_port() from
tests.utils.port_utils to avoid collisions.
""" """
if self._pgid is None: if self._pgid is None:
return return
try:
self._logger.info("Terminating process group: %s", self._pgid)
os.killpg(self._pgid, signal.SIGTERM) # Step 1: Graceful SIGTERM
except ProcessLookupError:
return
except Exception as e:
self._logger.warning(
"Error sending SIGTERM to process group %s: %s", self._pgid, e
)
return
# Step 2: Poll for process exit instead of fixed sleep to minimize teardown time # Step 1: Snapshot all process groups before signaling. Only self.proc runs
# in self._pgid (start_new_session=True); _tee_proc/_sed_proc are pipe
# helpers in pytest's group with no children.
all_pgids: set[int] = {self._pgid}
if self.proc and self.proc.poll() is None:
try:
for child in psutil.Process(self.proc.pid).children(recursive=True):
try:
all_pgids.add(os.getpgid(child.pid))
except (ProcessLookupError, OSError):
pass
except psutil.NoSuchProcess:
pass
# Step 2: SIGTERM all snapshotted process groups (graceful shutdown).
# Delivers the signal to every group so children that called
# setpgid()/setsid() also get a chance to shut down gracefully.
sigtermed_pgids: set[int] = set()
for pgid in all_pgids:
try:
self._logger.info("Sending SIGTERM to process group: %s", pgid)
os.killpg(pgid, signal.SIGTERM)
sigtermed_pgids.add(pgid)
except ProcessLookupError:
pass # Already gone
except Exception as e:
self._logger.warning(
"Error sending SIGTERM to process group %s: %s", pgid, e
)
if not sigtermed_pgids:
return # All groups already gone
# Step 3: poll for all process groups to exit.
# self.proc.poll() reaps the zombie so os.killpg(pgid, 0) can
# detect an empty group; without this the zombie keeps the group
# "alive" and the loop burns the full timeout.
poll_interval = 0.1 poll_interval = 0.1
elapsed = 0.0 elapsed = 0.0
while elapsed < timeout: while elapsed < timeout:
# Reap the launched child if it has already exited. Without this,
# the child can remain as a zombie and keep killpg(..., 0) reporting
# the process group as alive until the timeout expires.
if self.proc is not None: if self.proc is not None:
self.proc.poll() self.proc.poll()
try: # Check if any signaled group is still alive
# Check if any process in the group is still alive still_alive = False
os.killpg(self._pgid, 0) # Signal 0 = check existence (no kill) for pgid in sigtermed_pgids:
except ProcessLookupError: try:
# Process group no longer exists - done os.killpg(pgid, 0) # signal 0 = check existence
return still_alive = True
except Exception: break # At least one alive, keep waiting
# Other errors (e.g., permission) - assume done except (ProcessLookupError, OSError):
return pass
if not still_alive:
break
time.sleep(poll_interval) time.sleep(poll_interval)
elapsed += poll_interval elapsed += poll_interval
# Step 3: Force kill if anything remains after timeout # Step 4: SIGKILL all snapshotted process groups to ensure nothing
try: # survives (e.g. processes that ignored SIGTERM or timed out).
os.killpg(self._pgid, signal.SIGKILL) # SIGKILL (kill -9) - immediate # _stop_started_processes handles per-process wait/escalation afterward.
except ProcessLookupError: for pgid in all_pgids:
pass try:
except Exception as e: os.killpg(pgid, signal.SIGKILL)
self._logger.warning( except ProcessLookupError:
"Error sending SIGKILL to process group %s: %s", self._pgid, e pass
) except PermissionError:
self._logger.warning(
"Permission denied sending SIGKILL to process group %s "
"(should not happen for our own children)",
pgid,
)
except Exception as e:
self._logger.warning(
"Error sending SIGKILL to process group %s: %s", pgid, e
)
def _remove_directory(self, path: str) -> None: def _remove_directory(self, path: str) -> None:
"""Remove a directory.""" """Remove a directory."""
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for ManagedProcess teardown behavior.
Verifies that __exit__ / _terminate_process_group correctly kills process
trees under various scenarios: simple children, deep trees, children that
create their own process groups, and xdist-safe mode skipping stragglers.
All test processes are lightweight shell/python one-liners that sleep;
no GPU or network resources are needed.
IMPORTANT: Never use generic command names like "sleep" as stragglers or
command names with terminate_all_matching_process_names=True — that kills
container infrastructure (tail -f, sleep in docker-init, etc.).
Always use unique markers scoped to the test invocation.
"""
import os
import signal
import subprocess
import time
import uuid
import psutil
import pytest
from tests.utils.managed_process import ManagedProcess
# Module-level pytest markers: pytest applies every marker in `pytestmark`
# to each test collected from this file.
pytestmark = [
    pytest.mark.parallel,
    pytest.mark.gpu_0,
    pytest.mark.unit,
    pytest.mark.pre_merge,
]
def _unique_marker() -> str:
"""Per-call unique marker that won't collide across xdist workers."""
return f"__mp_test_{uuid.uuid4().hex[:12]}__"
def _pid_alive(pid: int) -> bool:
    """Report whether *pid* refers to a running, non-zombie process.

    Returns:
        True when psutil can read the process status and it is not
        ``STATUS_ZOMBIE``. False for zombies and whenever psutil raises
        ``NoSuchProcess`` or ``AccessDenied`` for the PID.
    """
    try:
        status = psutil.Process(pid).status()
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return False
    return status != psutil.STATUS_ZOMBIE
def _wait_for_pid_death(pid: int, timeout: float = 10.0) -> bool:
    """Poll until *pid* is dead or *timeout* seconds have elapsed.

    The PID is always probed at least once (even with ``timeout <= 0``) and
    is probed again after the final sleep, so a process that dies during the
    last polling interval is still reported as dead.

    Args:
        pid: Process ID to watch.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True if the process was observed dead (per ``_pid_alive``), False if
        it was still alive when the deadline passed.
    """
    deadline = time.monotonic() + timeout
    while True:
        if not _pid_alive(pid):
            return True
        # Deadline is checked *after* the probe so the last look always
        # happens at or past the deadline rather than one sleep before it.
        if time.monotonic() >= deadline:
            return False
        time.sleep(0.1)
def _collect_tree_pids(root_pid: int) -> set[int]:
    """Return the PID set made of *root_pid* plus all of its descendants.

    If psutil reports ``NoSuchProcess`` the lookup stops and whatever was
    gathered up to that point is returned (an empty set when the root itself
    is already gone).
    """
    found: set[int] = set()
    try:
        root = psutil.Process(root_pid)
        found.add(root_pid)
        found.update(desc.pid for desc in root.children(recursive=True))
    except psutil.NoSuchProcess:
        pass
    return found
def _wait_for_tree(
    root_pid: int, min_count: int, timeout: float = 3.0, poll: float = 0.1
) -> set[int]:
    """Poll until the process tree rooted at *root_pid* has enough members.

    A snapshot is always taken at least once (even with ``timeout <= 0``),
    and the value returned on timeout is a snapshot taken at or after the
    deadline rather than one taken before the final sleep.

    Args:
        root_pid: PID of the tree's root process.
        min_count: Number of PIDs (root included) to wait for.
        timeout: Maximum number of seconds to keep polling.
        poll: Seconds to sleep between snapshots.

    Returns:
        The most recent PID snapshot. It contains at least ``min_count``
        entries on success and fewer if the deadline passed first; callers
        are expected to check the size themselves.
    """
    deadline = time.monotonic() + timeout
    while True:
        pids = _collect_tree_pids(root_pid)
        if len(pids) >= min_count or time.monotonic() >= deadline:
            return pids
        time.sleep(poll)
def _bash_sleep_cmd(marker: str, tag: str = "") -> list[str]:
"""Return a bash command that sleeps 300s with an embedded unique marker.
The trailing `: noexit` prevents bash from exec-ing into sleep
(which would lose the marker from the cmdline)."""
return ["bash", "-c", f": {marker}{tag}; sleep 300; : noexit"]
# ---------------------------------------------------------------------------
# Scenario 1: Simple process with children — all should die on __exit__
# ---------------------------------------------------------------------------
class TestSimpleProcessTree:
    """Scenario 1: a parent with direct children is torn down completely."""

    def test_parent_and_children_killed(self, tmp_path):
        """Every PID of a parent-plus-children tree is dead after __exit__."""
        marker = _unique_marker()
        # Two background sleeps plus `wait` keep the bash parent alive with
        # children underneath it.
        script = f": {marker}; sleep 300 & sleep 300 & wait"
        mp = ManagedProcess(
            command=["bash", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            # Snapshot the tree while it is still running; the PIDs are
            # checked after the context manager has torn everything down.
            tree_pids = _wait_for_tree(mp.proc.pid, min_count=2)
            assert len(tree_pids) >= 2, f"Expected parent + children, got {tree_pids}"

        for pid in tree_pids:
            dead = _wait_for_pid_death(pid, timeout=10)
            assert dead, f"PID {pid} still alive after teardown"
# ---------------------------------------------------------------------------
# Scenario 2: Deep process tree (grandchildren)
# ---------------------------------------------------------------------------
class TestDeepProcessTree:
    """Scenario 2: grandchildren are torn down together with the parent."""

    def test_grandchildren_killed(self, tmp_path):
        """Parent -> child -> grandchild must all be dead after __exit__."""
        marker = _unique_marker()
        # Nested `bash -c` invocations build a three-level process chain.
        script = f": {marker}; bash -c 'bash -c \"sleep 300\" & wait' & wait"
        mp = ManagedProcess(
            command=["bash", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            # Capture the PIDs before teardown so they can be verified later.
            tree_pids = _wait_for_tree(mp.proc.pid, min_count=3)
            enough = len(tree_pids) >= 3
            assert enough, f"Expected parent + child + grandchild, got {tree_pids}"

        for pid in tree_pids:
            dead = _wait_for_pid_death(pid, timeout=10)
            assert dead, f"PID {pid} still alive after teardown"
# ---------------------------------------------------------------------------
# Scenario 3: Child creates its own process group (setpgid)
# ---------------------------------------------------------------------------
class TestChildWithOwnProcessGroup:
    """Scenario 3: a child that leaves the parent's process group still dies."""

    @staticmethod
    def _in_separate_pgid(pid: int, parent_pgid: int) -> bool:
        """Return True if *pid* is alive and its pgid differs from *parent_pgid*."""
        try:
            return os.getpgid(pid) != parent_pgid
        except OSError:  # includes ProcessLookupError
            return False

    def test_child_in_own_pgid_killed(self, tmp_path):
        """A child that calls setpgid(0,0) to leave the parent's group
        should still be killed via the snapshotted pgid set."""
        # Forked child moves into its own process group and sleeps; the
        # parent just sleeps so the tree stays alive until teardown.
        script = (
            "import os, time; "
            "pid = os.fork(); "
            "_ = (os.setpgid(0, 0), time.sleep(300)) if pid == 0 else "
            "(time.sleep(0.3), time.sleep(300))"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            root_pid = mp.proc.pid
            tree_pids = _wait_for_tree(root_pid, min_count=2)
            assert len(tree_pids) >= 2, f"Expected parent + child, got {tree_pids}"

            child_pids = tree_pids - {root_pid}
            parent_pgid = os.getpgid(root_pid)

            # The child becomes visible in the tree right after fork(), but
            # only calls setpgid() afterwards. Poll briefly instead of
            # checking once, otherwise the test can skip spuriously when it
            # observes the child before the group change.
            deadline = time.monotonic() + 3.0
            while True:
                found_separate_pgid = any(
                    self._in_separate_pgid(cpid, parent_pgid) for cpid in child_pids
                )
                if found_separate_pgid or time.monotonic() >= deadline:
                    break
                time.sleep(0.05)

            if not found_separate_pgid:
                pytest.skip("Child didn't get a separate pgid (OS-dependent)")

        for pid in tree_pids:
            assert _wait_for_pid_death(
                pid, timeout=10
            ), f"PID {pid} still alive after teardown (separate pgid scenario)"
# ---------------------------------------------------------------------------
# Scenario 4: xdist-safe mode skips _cleanup_stragglers
# ---------------------------------------------------------------------------
class TestXdistSafeSkipsStragglers:
    """Scenario 4: straggler cleanup runs only outside xdist-safe mode."""

    @staticmethod
    def _kill_and_reap(bystander: subprocess.Popen) -> None:
        """SIGKILL the bystander's process group and reap the Popen child.

        Makes up to two kill+wait attempts so the child is reaped even when
        the first wait times out; without the second wait a zombie could
        linger in the process table for the rest of the pytest session.
        Errors from an already-gone process or group are ignored.
        """
        for _ in range(2):
            try:
                os.killpg(os.getpgid(bystander.pid), signal.SIGKILL)
            except OSError:  # ProcessLookupError / PermissionError included
                pass
            try:
                bystander.wait(timeout=2)
            except subprocess.TimeoutExpired:
                continue
            return

    def test_stragglers_not_killed_in_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=False, _cleanup_stragglers
        should NOT kill unrelated processes matching the straggler pattern."""
        marker = _unique_marker()
        # Independent process (own session) whose cmdline matches the
        # straggler pattern but which ManagedProcess did not launch.
        bystander = subprocess.Popen(
            _bash_sleep_cmd(marker, "bystander"),
            start_new_session=True,
        )
        try:
            mp = ManagedProcess(
                command=_bash_sleep_cmd(marker, "main"),
                timeout=10,
                display_output=False,
                terminate_all_matching_process_names=False,
                straggler_commands=[marker],
                log_dir=str(tmp_path),
            )
            with mp:
                pass
            assert _pid_alive(
                bystander.pid
            ), "Bystander was killed even though xdist-safe mode was on"
        finally:
            self._kill_and_reap(bystander)

    def test_stragglers_killed_when_not_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=True, _cleanup_stragglers
        SHOULD kill processes matching the straggler pattern."""
        marker = _unique_marker()
        victim_tag = f"{marker}_victim"
        launcher_tag = f"{marker}_launcher"
        # Same command shape as the other test's bystander, tagged as victim.
        bystander = subprocess.Popen(
            _bash_sleep_cmd(victim_tag),
            start_new_session=True,
        )
        try:
            mp = ManagedProcess(
                command=["bash", "-c", f": {launcher_tag}; sleep 1"],
                timeout=10,
                display_output=False,
                display_name=launcher_tag,
                terminate_all_matching_process_names=True,
                straggler_commands=[victim_tag],
                log_dir=str(tmp_path),
            )
            with mp:
                time.sleep(0.5)
            assert _wait_for_pid_death(
                bystander.pid, timeout=10
            ), "Bystander with matching straggler command should have been killed"
        finally:
            self._kill_and_reap(bystander)
# ---------------------------------------------------------------------------
# Scenario 5: Process already dead before __exit__
# ---------------------------------------------------------------------------
class TestAlreadyDeadProcess:
    """Scenario 5: teardown tolerates a process that has already exited."""

    def test_exit_handles_dead_process(self, tmp_path):
        """__exit__ must not raise when the command finished on its own."""
        short_lived = ManagedProcess(
            command=["bash", "-c", "exit 0"],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )
        with short_lived:
            # Give the command time to exit before teardown starts.
            time.sleep(0.5)
        # Reaching this point without an exception is the pass condition.
# ---------------------------------------------------------------------------
# Scenario 6: SIGTERM grace period — process that traps SIGTERM
# ---------------------------------------------------------------------------
class TestSigtermGracePeriod:
    """Scenario 6: a process that traps SIGTERM gets to run its handler."""

    def test_process_gets_sigterm_grace_before_sigkill(self, tmp_path):
        """A process that handles SIGTERM and takes a moment to exit should
        get the grace period, not be immediately SIGKILLed.

        Uses a Python child that writes a "ready" file after installing its
        SIGTERM handler, so we don't race against interpreter startup."""
        marker_file = str(tmp_path / "got_sigterm")
        ready_file = str(tmp_path / "ready")
        # Paths are embedded with !r so they become valid Python string
        # literals in the child script even if tmp_path contains quotes or
        # backslashes.
        script = (
            "import os, signal, time, pathlib; "
            f"marker = pathlib.Path({marker_file!r}); "
            f"ready = pathlib.Path({ready_file!r}); "
            "signal.signal(signal.SIGTERM, "
            "lambda *_: (marker.touch(), os._exit(0))); "
            "ready.touch(); "
            "[time.sleep(0.1) for _ in iter(int, 1)]"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            # Wait until the child reports that its handler is installed;
            # tearing down earlier would hit the default SIGTERM action.
            deadline = time.monotonic() + 5.0
            while not os.path.exists(ready_file):
                assert time.monotonic() < deadline, "Child never became ready"
                time.sleep(0.05)

        # The handler touches marker_file, so its presence after teardown
        # shows SIGTERM was delivered and handled before any SIGKILL.
        assert os.path.exists(
            marker_file
        ), "Process was SIGKILLed before SIGTERM handler could run"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment