Commit 608405e0 (unverified), authored by Keiven C, committed by GitHub
Browse files

feat: harden ManagedProcess teardown, add xdist-safe tests (#6670)


Signed-off-by: Keiven Chang <keivenchang@users.noreply.github.com>
parent 7e07495f
......@@ -138,7 +138,7 @@ class ManagedProcess:
# ✅ SAFE for parallel tests (only kills what we launch):
ManagedProcess(
command=["vllm", "serve", "--port", "8000"],
command=["vllm", "serve", "--port", str(dynamically_allocated_port)],
terminate_all_matching_process_names=False, # Don't kill other processes
)
......@@ -237,7 +237,16 @@ class ManagedProcess:
raise
def _cleanup_stragglers(self):
"""Clean up straggler processes - called during exit and signal handling"""
"""Clean up straggler processes - called during exit and signal handling.
WARNING: NOT pytest-xdist safe! This does a system-wide sweep matching by
process name (self.stragglers) and command-line substring (self.straggler_commands),
similar to _terminate_all_matching_process_names. Skipped when
terminate_all_matching_process_names=False (i.e. xdist-safe mode) to avoid
killing other workers' processes.
"""
if not self.terminate_all_matching_process_names:
return
try:
if self.stragglers or self.straggler_commands:
self._logger.info(
......@@ -253,7 +262,9 @@ class ManagedProcess:
self._logger.info(
"Terminating Straggler %s %s", process_name, ps_process.pid
)
terminate_process_tree(ps_process.pid, self._logger)
terminate_process_tree(
ps_process.pid, self._logger, immediate_kill=True
)
# Check command line arguments
cmdline = ps_process.cmdline()
......@@ -266,7 +277,9 @@ class ManagedProcess:
ps_process.pid,
straggler_cmd,
)
terminate_process_tree(ps_process.pid, self._logger)
terminate_process_tree(
ps_process.pid, self._logger, immediate_kill=True
)
break # Avoid terminating the same process multiple times
except (
psutil.NoSuchProcess,
......@@ -285,65 +298,20 @@ class ManagedProcess:
def __exit__(self, exc_type, exc_val, exc_tb):
"""Cleanup: Terminate launched processes.
Termination Strategy (Graceful → Escalate to Force):
=====================================================
1. Snapshot child processes (before killing parent, so we can find them)
2. Send SIGTERM to process group immediately (no delay before SIGTERM)
3. Wait up to 2s (poll every 0.1s) for processes to exit
4. If still alive after 2s: Send SIGKILL (force kill)
5. Terminate individual processes (self.proc, tee, sed):
- Send SIGTERM immediately (no delay)
- Wait up to 10s for exit
- If still alive after 10s: Send SIGKILL (force kill)
6. Kill any orphaned child processes that escaped the process group
(e.g. TRT-LLM engine workers started via MPI in separate PGIDs)
7. Clean up straggler processes (if configured)
Signal Details:
- SIGTERM (15): Graceful - allows cleanup handlers to run
- SIGKILL (9): Force kill - immediate, cannot be caught
Timeout Parameter:
- Controls how long to WAIT AFTER SIGTERM before escalating to SIGKILL
- NOT a delay before sending SIGTERM (SIGTERM is sent immediately)
This ALWAYS runs regardless of terminate_all_matching_process_names setting.
Termination Strategy:
=====================
1. _stop_started_processes: calls _terminate_process_group (SIGTERM →
poll → SIGKILL all snapshotted pgids), closes pipes, waits/escalates
per-process.
2. Clean up data directory if configured.
3. Run name-based straggler cleanup (only when not in xdist-safe mode).
"""
try:
# Snapshot all child processes BEFORE killing the parent.
# Some frameworks (e.g. TRT-LLM via MPI) spawn children in separate
# process groups. After the parent dies these children become orphans
# reparented to init, so psutil can no longer find them via the parent
# PID. We must capture them now while the parent is still alive.
orphan_candidates = []
if self.proc:
try:
parent = psutil.Process(self.proc.pid)
orphan_candidates = parent.children(recursive=True)
except psutil.NoSuchProcess:
pass
self._stop_started_processes()
# Kill any child processes that survived the process group kill.
# This catches children in different PGIDs (e.g. MPI workers, engine
# cores) that os.killpg() could not reach.
for child in orphan_candidates:
try:
if child.is_running():
self._logger.info(
"Killing orphaned child process: PID %s name=%s",
child.pid,
child.name(),
)
terminate_process_tree(child.pid, self._logger)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
if self.data_dir:
self._remove_directory(self.data_dir)
finally:
# Always run straggler cleanup, even if interrupted
self._cleanup_stragglers()
def _stop_started_processes(self, wait_timeout: float = 10.0):
......@@ -524,63 +492,108 @@ class ManagedProcess:
Kill Sequence:
==============
1. Send SIGTERM to entire process group IMMEDIATELY (no delay)
2. Wait up to `timeout` seconds (default 8s), polling every 0.1s
3. If still alive after timeout: Send SIGKILL (force kill, immediate)
Timeout Parameter:
- Controls how long to WAIT AFTER SIGTERM before escalating to SIGKILL
- NOT a delay before sending SIGTERM (SIGTERM is sent immediately)
- 8s gives engines (TRT-LLM, vLLM, etc.) enough time to gracefully
shut down MPI workers, release GPU memory, and drain pending requests
- Polling at 0.1s intervals means fast exits are not penalized
Process groups catch cases where the launcher shell exits and its
children are reparented, leaving no parent PID to traverse, but they
remain in the same process group.
1. Snapshot all child process groups (the main pgid plus any distinct
child pgids) before sending any signals.
2. Send SIGTERM to ALL snapshotted process groups (no delay).
os.killpg delivers the signal atomically to every group member.
3. Wait up to `timeout` seconds (default 8s), polling every 0.1s.
8s usually gives engines (e.g. vLLM EngineCore) enough time to flush
in-flight work and release resources before we escalate.
4. SIGKILL ALL snapshotted process groups (those that ignored SIGTERM
or timed out).
Limitation: if a child calls setsid()/setpgid() AFTER our snapshot,
its new group won't be in our set. This is rare -- Python
multiprocessing and vLLM engine core inherit the parent's group.
Post-SIGKILL resource cleanup notes:
- GPU memory: vLLM Engine releases VRAM on process death (driver-level
reclaim). No manual GPU cleanup needed after SIGKILL.
- Distributed state: NATS subscriptions, etcd leases, and KV cache
index entries are NOT cleaned up by SIGKILL. These persist until
TTL expiry or explicit purge.
- Shared memory: POSIX shm segments (/dev/shm/) survive process
death and must be unlinked separately if used.
- Network ports: TCP sockets enter TIME_WAIT (~60s); rebinding the
same port immediately may fail. Prefer allocate_port() from
tests.utils.port_utils to avoid collisions.
"""
if self._pgid is None:
return
# Step 1: Snapshot all process groups before signaling. Only self.proc runs
# in self._pgid (start_new_session=True); _tee_proc/_sed_proc are pipe
# helpers in pytest's group with no children.
all_pgids: set[int] = {self._pgid}
if self.proc and self.proc.poll() is None:
try:
for child in psutil.Process(self.proc.pid).children(recursive=True):
try:
all_pgids.add(os.getpgid(child.pid))
except (ProcessLookupError, OSError):
pass
except psutil.NoSuchProcess:
pass
# Step 2: SIGTERM all snapshotted process groups (graceful shutdown).
# Delivers the signal to every group so children that called
# setpgid()/setsid() also get a chance to shut down gracefully.
sigtermed_pgids: set[int] = set()
for pgid in all_pgids:
try:
self._logger.info("Terminating process group: %s", self._pgid)
os.killpg(self._pgid, signal.SIGTERM) # Step 1: Graceful SIGTERM
self._logger.info("Sending SIGTERM to process group: %s", pgid)
os.killpg(pgid, signal.SIGTERM)
sigtermed_pgids.add(pgid)
except ProcessLookupError:
return
pass # Already gone
except Exception as e:
self._logger.warning(
"Error sending SIGTERM to process group %s: %s", self._pgid, e
"Error sending SIGTERM to process group %s: %s", pgid, e
)
return
# Step 2: Poll for process exit instead of fixed sleep to minimize teardown time
if not sigtermed_pgids:
return # All groups already gone
# Step 3: poll for all process groups to exit.
# self.proc.poll() reaps the zombie so os.killpg(pgid, 0) can
# detect an empty group; without this the zombie keeps the group
# "alive" and the loop burns the full timeout.
poll_interval = 0.1
elapsed = 0.0
while elapsed < timeout:
# Reap the launched child if it has already exited. Without this,
# the child can remain as a zombie and keep killpg(..., 0) reporting
# the process group as alive until the timeout expires.
if self.proc is not None:
self.proc.poll()
# Check if any signaled group is still alive
still_alive = False
for pgid in sigtermed_pgids:
try:
# Check if any process in the group is still alive
os.killpg(self._pgid, 0) # Signal 0 = check existence (no kill)
except ProcessLookupError:
# Process group no longer exists - done
return
except Exception:
# Other errors (e.g., permission) - assume done
return
os.killpg(pgid, 0) # signal 0 = check existence
still_alive = True
break # At least one alive, keep waiting
except (ProcessLookupError, OSError):
pass
if not still_alive:
break
time.sleep(poll_interval)
elapsed += poll_interval
# Step 3: Force kill if anything remains after timeout
# Step 4: SIGKILL all snapshotted process groups to ensure nothing
# survives (e.g. processes that ignored SIGTERM or timed out).
# _stop_started_processes handles per-process wait/escalation afterward.
for pgid in all_pgids:
try:
os.killpg(self._pgid, signal.SIGKILL) # SIGKILL (kill -9) - immediate
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
pass
except PermissionError:
self._logger.warning(
"Permission denied sending SIGKILL to process group %s "
"(should not happen for our own children)",
pgid,
)
except Exception as e:
self._logger.warning(
"Error sending SIGKILL to process group %s: %s", self._pgid, e
"Error sending SIGKILL to process group %s: %s", pgid, e
)
def _remove_directory(self, path: str) -> None:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for ManagedProcess teardown behavior.
Verifies that __exit__ / _terminate_process_group correctly kills process
trees under various scenarios: simple children, deep trees, children that
create their own process groups, and xdist-safe mode skipping stragglers.
All test processes are lightweight shell/python one-liners that sleep;
no GPU or network resources are needed.
IMPORTANT: Never use generic command names like "sleep" as stragglers or
command names with terminate_all_matching_process_names=True — that kills
container infrastructure (tail -f, sleep in docker-init, etc.).
Always use unique markers scoped to the test invocation.
"""
import os
import signal
import subprocess
import time
import uuid
import psutil
import pytest
from tests.utils.managed_process import ManagedProcess
# Module-level pytest markers: pytest applies every entry of ``pytestmark``
# to each test collected from this file.
pytestmark = [
pytest.mark.parallel,
pytest.mark.gpu_0,
pytest.mark.unit,
pytest.mark.pre_merge,
]
def _unique_marker() -> str:
    """Build a fresh marker string for tagging one test's processes.

    The marker embeds the first 12 hex characters of a random UUID4, so
    separate callers (for example different xdist workers) get different
    values.
    """
    token = uuid.uuid4().hex[:12]
    return "__mp_test_" + token + "__"
def _pid_alive(pid: int) -> bool:
    """Report whether *pid* refers to a live, non-zombie process.

    A zombie counts as dead. A PID that psutil cannot find, or that it is
    denied access to, is likewise reported as not alive.
    """
    try:
        status = psutil.Process(pid).status()
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return False
    return status != psutil.STATUS_ZOMBIE
def _wait_for_pid_death(pid: int, timeout: float = 10.0) -> bool:
    """Poll until *pid* is dead or *timeout* seconds have elapsed.

    The PID is always checked at least once, even when *timeout* is zero or
    negative, and the last check happens at the deadline rather than up to
    one poll interval before it. A process that dies during the final sleep
    is therefore still reported as dead.

    Args:
        pid: Process ID to watch.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True if the process was observed dead (per ``_pid_alive``), False if
        it was still alive at the deadline.
    """
    poll_interval = 0.1
    deadline = time.monotonic() + timeout
    while True:
        if not _pid_alive(pid):
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline, so the next check lands on it.
        time.sleep(min(poll_interval, remaining))
def _collect_tree_pids(root_pid: int) -> set[int]:
    """Gather *root_pid* together with the PIDs of all its descendants.

    Returns an empty set when the root process does not exist. If the root
    disappears while its children are being listed, only the root PID is
    returned.
    """
    found: set[int] = set()
    try:
        root = psutil.Process(root_pid)
        # Record the root before listing children so it is kept even if the
        # listing raises NoSuchProcess.
        found.add(root_pid)
        found.update(proc.pid for proc in root.children(recursive=True))
    except psutil.NoSuchProcess:
        pass
    return found
def _wait_for_tree(
    root_pid: int, min_count: int, timeout: float = 3.0, poll: float = 0.1
) -> set[int]:
    """Wait for the process tree under *root_pid* to reach *min_count* PIDs.

    Re-collects the tree every *poll* seconds for at most *timeout* seconds.
    Returns the first snapshot that is large enough, or the last snapshot
    taken if the timeout expires first (an empty set if none was taken).
    """
    snapshot: set[int] = set()
    give_up_at = time.monotonic() + timeout
    while time.monotonic() < give_up_at:
        snapshot = _collect_tree_pids(root_pid)
        if len(snapshot) < min_count:
            time.sleep(poll)
            continue
        break
    return snapshot
def _bash_sleep_cmd(marker: str, tag: str = "") -> list[str]:
    """Build a ``bash -c`` argv that sleeps 300s and carries a unique marker.

    The script starts with a ``:`` no-op whose argument is *marker*
    immediately followed by *tag*, so the marker shows up on the command
    line. The trailing ``: noexit`` prevents bash from exec-ing into sleep
    (which would lose the marker from the cmdline).
    """
    script = "; ".join((f": {marker}{tag}", "sleep 300", ": noexit"))
    return ["bash", "-c", script]
# ---------------------------------------------------------------------------
# Scenario 1: Simple process with children — all should die on __exit__
# ---------------------------------------------------------------------------
class TestSimpleProcessTree:
    """Teardown of a launcher that backgrounds direct children."""

    def test_parent_and_children_killed(self, tmp_path):
        """Every PID of a parent-plus-children tree must be dead after ``__exit__``."""
        tag = _unique_marker()
        # The leading ``:`` no-op only carries the marker on bash's command line.
        launch_cmd = ["bash", "-c", f": {tag}; sleep 300 & sleep 300 & wait"]
        mp = ManagedProcess(
            command=launch_cmd,
            log_dir=str(tmp_path),
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
        )

        with mp:
            assert mp.proc is not None
            launcher_pid = mp.proc.pid
            # Snapshot the tree while it is alive; it is checked after exit.
            tree_pids = _wait_for_tree(launcher_pid, min_count=2)
            assert len(tree_pids) >= 2, f"Expected parent + children, got {tree_pids}"

        # The context manager has exited, so teardown has already run.
        for pid in tree_pids:
            gone = _wait_for_pid_death(pid, timeout=10)
            assert gone, f"PID {pid} still alive after teardown"
# ---------------------------------------------------------------------------
# Scenario 2: Deep process tree (grandchildren)
# ---------------------------------------------------------------------------
class TestDeepProcessTree:
    """Teardown of a three-level tree: launcher, child and grandchild."""

    def test_grandchildren_killed(self, tmp_path):
        """Parent, child and grandchild must all be dead after ``__exit__``."""
        tag = _unique_marker()
        # Nested ``bash -c`` invocations create the extra generations.
        nested_cmd = [
            "bash",
            "-c",
            f": {tag}; bash -c 'bash -c \"sleep 300\" & wait' & wait",
        ]
        mp = ManagedProcess(
            command=nested_cmd,
            log_dir=str(tmp_path),
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
        )

        with mp:
            assert mp.proc is not None
            launcher_pid = mp.proc.pid
            tree_pids = _wait_for_tree(launcher_pid, min_count=3)
            assert (
                len(tree_pids) >= 3
            ), f"Expected parent + child + grandchild, got {tree_pids}"

        # Checked after the ``with`` block, i.e. once teardown has run.
        for pid in tree_pids:
            gone = _wait_for_pid_death(pid, timeout=10)
            assert gone, f"PID {pid} still alive after teardown"
# ---------------------------------------------------------------------------
# Scenario 3: Child creates its own process group (setpgid)
# ---------------------------------------------------------------------------
class TestChildWithOwnProcessGroup:
    """Teardown when a child has moved itself into its own process group."""

    def test_child_in_own_pgid_killed(self, tmp_path):
        """A child that left the parent's group via ``setpgid(0, 0)`` must
        still be dead after teardown."""
        # One-liner: fork, the child calls setpgid(0, 0) and sleeps, and the
        # parent sleeps as well so both stay alive until teardown.
        script = (
            "import os, time; "
            "pid = os.fork(); "
            "_ = (os.setpgid(0, 0), time.sleep(300)) if pid == 0 else "
            "(time.sleep(0.3), time.sleep(300))"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            log_dir=str(tmp_path),
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
        )

        with mp:
            assert mp.proc is not None
            root_pid = mp.proc.pid
            tree_pids = _wait_for_tree(root_pid, min_count=2)
            assert len(tree_pids) >= 2, f"Expected parent + child, got {tree_pids}"

            descendants = tree_pids - {root_pid}
            parent_pgid = os.getpgid(root_pid)

            def _left_parent_group(candidate: int) -> bool:
                """True when *candidate* is in a different group than the root."""
                try:
                    return os.getpgid(candidate) != parent_pgid
                except (ProcessLookupError, OSError):
                    # The process vanished or cannot be queried: no match.
                    return False

            # The scenario is only meaningful if a child really left the group.
            if not any(_left_parent_group(cpid) for cpid in descendants):
                pytest.skip("Child didn't get a separate pgid (OS-dependent)")

        for pid in tree_pids:
            gone = _wait_for_pid_death(pid, timeout=10)
            assert (
                gone
            ), f"PID {pid} still alive after teardown (separate pgid scenario)"
# ---------------------------------------------------------------------------
# Scenario 4: xdist-safe mode skips _cleanup_stragglers
# ---------------------------------------------------------------------------
class TestXdistSafeSkipsStragglers:
    """Straggler cleanup must follow ``terminate_all_matching_process_names``.

    Each test starts an unrelated "bystander" process whose command line
    matches a configured straggler pattern, then checks whether tearing down
    a ManagedProcess kills it.
    """

    @staticmethod
    def _kill_and_reap(bystander: subprocess.Popen) -> None:
        """SIGKILL the bystander's process group and reap the child.

        Best-effort cleanup shared by both tests. The kill/wait pair runs at
        most twice; the retry is also followed by a wait, so the child does
        not linger as a zombie for the rest of the pytest session.
        """
        for _attempt in range(2):
            try:
                os.killpg(os.getpgid(bystander.pid), signal.SIGKILL)
            except OSError:
                # Covers ProcessLookupError and PermissionError (both are
                # OSError subclasses): the group is gone or not ours to signal.
                pass
            try:
                bystander.wait(timeout=2)
                return
            except subprocess.TimeoutExpired:
                continue

    def test_stragglers_not_killed_in_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=False, _cleanup_stragglers
        should NOT kill unrelated processes matching the straggler pattern."""
        marker = _unique_marker()
        # start_new_session=True puts the bystander in its own process group.
        bystander = subprocess.Popen(
            _bash_sleep_cmd(marker, "bystander"),
            start_new_session=True,
        )
        bystander_pid = bystander.pid

        try:
            mp = ManagedProcess(
                command=_bash_sleep_cmd(marker, "main"),
                timeout=10,
                display_output=False,
                terminate_all_matching_process_names=False,
                straggler_commands=[marker],
                log_dir=str(tmp_path),
            )
            with mp:
                pass

            assert _pid_alive(
                bystander_pid
            ), "Bystander was killed even though xdist-safe mode was on"
        finally:
            self._kill_and_reap(bystander)

    def test_stragglers_killed_when_not_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=True, _cleanup_stragglers
        SHOULD kill processes matching the straggler pattern."""
        marker = _unique_marker()
        victim_tag = f"{marker}_victim"
        launcher_tag = f"{marker}_launcher"

        # Same command shape as the other test, built by the shared helper.
        bystander = subprocess.Popen(
            _bash_sleep_cmd(victim_tag),
            start_new_session=True,
        )
        bystander_pid = bystander.pid

        try:
            mp = ManagedProcess(
                command=["bash", "-c", f": {launcher_tag}; sleep 1"],
                timeout=10,
                display_output=False,
                display_name=launcher_tag,
                terminate_all_matching_process_names=True,
                straggler_commands=[victim_tag],
                log_dir=str(tmp_path),
            )
            with mp:
                time.sleep(0.5)

            assert _wait_for_pid_death(
                bystander_pid, timeout=10
            ), "Bystander with matching straggler command should have been killed"
        finally:
            self._kill_and_reap(bystander)
# ---------------------------------------------------------------------------
# Scenario 5: Process already dead before __exit__
# ---------------------------------------------------------------------------
class TestAlreadyDeadProcess:
    """Teardown when the launched command has already finished."""

    def test_exit_handles_dead_process(self, tmp_path):
        """``__exit__`` must not raise if the process exited by itself first."""
        short_lived = ManagedProcess(
            log_dir=str(tmp_path),
            terminate_all_matching_process_names=False,
            display_output=False,
            timeout=10,
            command=["bash", "-c", "exit 0"],
        )
        with short_lived:
            # Leave time for ``exit 0`` to complete before teardown starts.
            time.sleep(0.5)
        # Reaching this point without an exception is the pass condition.
# ---------------------------------------------------------------------------
# Scenario 6: SIGTERM grace period — process that traps SIGTERM
# ---------------------------------------------------------------------------
class TestSigtermGracePeriod:
    """Teardown must deliver SIGTERM and leave time for a handler to run."""

    def test_process_gets_sigterm_grace_before_sigkill(self, tmp_path):
        """A process that handles SIGTERM and takes a moment to exit should
        get the grace period, not be immediately SIGKILLed.

        Uses a Python child that writes a "ready" file after installing its
        SIGTERM handler, so we don't race against interpreter startup. The
        handler touches a marker file, which the test checks after teardown.
        """
        marker_file = str(tmp_path / "got_sigterm")
        ready_file = str(tmp_path / "ready")
        # ``!r`` emits a correctly escaped Python string literal, so the
        # generated script stays valid even if the temp path contains quotes
        # or backslashes.
        script = (
            "import os, signal, time, pathlib; "
            f"marker = pathlib.Path({marker_file!r}); "
            f"ready = pathlib.Path({ready_file!r}); "
            "signal.signal(signal.SIGTERM, "
            "lambda *_: (marker.touch(), os._exit(0))); "
            "ready.touch(); "
            "[time.sleep(0.1) for _ in iter(int, 1)]"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            # Do not start teardown until the child has installed its handler.
            deadline = time.monotonic() + 5.0
            while not os.path.exists(ready_file):
                assert time.monotonic() < deadline, "Child never became ready"
                time.sleep(0.05)

        # The marker file exists only if the SIGTERM handler actually ran.
        assert os.path.exists(
            marker_file
        ), "Process was SIGKILLed before SIGTERM handler could run"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment