test_managed_process_teardown.py 12.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for ManagedProcess teardown behavior.

Verifies that __exit__ / _terminate_process_group correctly kills process
trees under various scenarios: simple children, deep trees, children that
create their own process groups, and xdist-safe mode skipping stragglers.

All test processes are lightweight shell/python one-liners that sleep;
no GPU or network resources are needed.

IMPORTANT: Never use generic command names like "sleep" as straggler
patterns, or as the command name when
terminate_all_matching_process_names=True: matching on them kills unrelated
container infrastructure (tail -f, sleep in docker-init, etc.). Always
match on unique markers scoped to the test invocation.
"""

import os
import signal
import subprocess
import time
import uuid

import psutil
import pytest

from tests.utils.managed_process import ManagedProcess

# Marks applied to every test in this module. NOTE(review): gpu_0 presumably
# means "needs no GPU" (consistent with the module docstring's claim that no
# GPU resources are used) — confirm against the repo's marker definitions.
pytestmark = [
    pytest.mark.parallel,
    pytest.mark.gpu_0,
    pytest.mark.unit,
    pytest.mark.pre_merge,
]


def _unique_marker() -> str:
    """Per-call unique marker that won't collide across xdist workers."""
    return f"__mp_test_{uuid.uuid4().hex[:12]}__"


def _pid_alive(pid: int) -> bool:
    """Report whether *pid* refers to a live, non-zombie process.

    A zombie has already exited and is only waiting to be reaped, so it is
    reported as dead. A PID that has vanished, or one psutil is denied
    access to, is also reported as dead.
    """
    try:
        state = psutil.Process(pid).status()
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return False
    return state != psutil.STATUS_ZOMBIE


def _wait_for_pid_death(pid: int, timeout: float = 10.0) -> bool:
    """Poll until PID is dead or timeout. Returns True if dead."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if not _pid_alive(pid):
            return True
        time.sleep(0.1)
    return False


def _collect_tree_pids(root_pid: int) -> set[int]:
    """Snapshot *root_pid* together with every descendant visible right now.

    Returns an empty set if the root process cannot be found. If the root
    is found but enumerating its descendants raises NoSuchProcess, only the
    root PID is returned.
    """
    found: set[int] = set()
    try:
        root = psutil.Process(root_pid)
        found.add(root_pid)
        found.update(proc.pid for proc in root.children(recursive=True))
    except psutil.NoSuchProcess:
        pass
    return found


def _wait_for_tree(
    root_pid: int, min_count: int, timeout: float = 3.0, poll: float = 0.1
) -> set[int]:
    """Poll until the process tree has at least min_count members."""
    deadline = time.monotonic() + timeout
    pids: set[int] = set()
    while time.monotonic() < deadline:
        pids = _collect_tree_pids(root_pid)
        if len(pids) >= min_count:
            return pids
        time.sleep(poll)
    return pids


def _bash_sleep_cmd(marker: str, tag: str = "") -> list[str]:
    """Return a bash command that sleeps 300s with an embedded unique marker.
    The trailing `: noexit` prevents bash from exec-ing into sleep
    (which would lose the marker from the cmdline)."""
    return ["bash", "-c", f": {marker}{tag}; sleep 300; : noexit"]


# ---------------------------------------------------------------------------
# Scenario 1: Simple process with children — all should die on __exit__
# ---------------------------------------------------------------------------
class TestSimpleProcessTree:
    def test_parent_and_children_killed(self, tmp_path):
        """A bash parent with two background ``sleep`` children; the parent
        and both children must be dead after __exit__.

        The script forks exactly two background children, so we wait for a
        tree of at least three PIDs (parent + 2). Waiting for only two would
        let the snapshot contain a single child, and the other child's
        survival would then go unchecked.
        """
        marker = _unique_marker()
        mp = ManagedProcess(
            command=[
                "bash",
                "-c",
                f": {marker}; sleep 300 & sleep 300 & wait",
            ],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            root_pid = mp.proc.pid
            tree_pids = _wait_for_tree(root_pid, min_count=3)
            assert (
                len(tree_pids) >= 3
            ), f"Expected parent + 2 children, got {tree_pids}"

        for pid in tree_pids:
            assert _wait_for_pid_death(
                pid, timeout=10
            ), f"PID {pid} still alive after teardown"


# ---------------------------------------------------------------------------
# Scenario 2: Deep process tree (grandchildren)
# ---------------------------------------------------------------------------
class TestDeepProcessTree:
    def test_grandchildren_killed(self, tmp_path):
        """Teardown must reach grandchildren, not just direct children.

        Tree shape: root bash -> background ``bash -c '...'`` -> background
        ``bash -c "sleep 300"``. At least three PIDs are snapshotted while the
        tree is running, and each must be dead once the context exits.
        """
        tag = _unique_marker()
        script = f": {tag}; bash -c 'bash -c \"sleep 300\" & wait' & wait"
        managed = ManagedProcess(
            command=["bash", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with managed:
            assert managed.proc is not None
            snapshot = _wait_for_tree(managed.proc.pid, min_count=3)
            assert (
                len(snapshot) >= 3
            ), f"Expected parent + child + grandchild, got {snapshot}"

        for pid in snapshot:
            died = _wait_for_pid_death(pid, timeout=10)
            assert died, f"PID {pid} still alive after teardown"


# ---------------------------------------------------------------------------
# Scenario 3: Child creates its own process group (setpgid)
# ---------------------------------------------------------------------------
class TestChildWithOwnProcessGroup:
    @staticmethod
    def _any_in_other_pgid(pids: set[int], parent_pgid: int) -> bool:
        """Return True if any PID in *pids* currently has a pgid other than
        *parent_pgid*. PIDs that vanish mid-check are skipped."""
        for pid in pids:
            try:
                if os.getpgid(pid) != parent_pgid:
                    return True
            except OSError:  # includes ProcessLookupError
                continue
        return False

    def test_child_in_own_pgid_killed(self, tmp_path):
        """A child that calls setpgid(0,0) to leave the parent's group
        should still be killed via the snapshotted pgid set.

        The child calls setpgid() only after fork() returns in it, so the
        tree can be observed before the pgid actually changes. We therefore
        poll for the separate pgid briefly instead of deciding (and possibly
        skipping) on the very first look.
        """
        script = (
            "import os, time; "
            "pid = os.fork(); "
            "_ = (os.setpgid(0, 0), time.sleep(300)) if pid == 0 else "
            "(time.sleep(0.3), time.sleep(300))"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            root_pid = mp.proc.pid
            tree_pids = _wait_for_tree(root_pid, min_count=2)
            assert len(tree_pids) >= 2, f"Expected parent + child, got {tree_pids}"

            child_pids = tree_pids - {root_pid}
            parent_pgid = os.getpgid(root_pid)
            deadline = time.monotonic() + 3.0
            while not self._any_in_other_pgid(child_pids, parent_pgid):
                if time.monotonic() >= deadline:
                    pytest.skip("Child didn't get a separate pgid (OS-dependent)")
                time.sleep(0.05)

        for pid in tree_pids:
            assert _wait_for_pid_death(
                pid, timeout=10
            ), f"PID {pid} still alive after teardown (separate pgid scenario)"


# ---------------------------------------------------------------------------
# Scenario 4: xdist-safe mode skips _cleanup_stragglers
# ---------------------------------------------------------------------------
class TestXdistSafeSkipsStragglers:
    """Straggler cleanup must honour terminate_all_matching_process_names."""

    @staticmethod
    def _kill_and_reap(proc: subprocess.Popen) -> None:
        """Best-effort SIGKILL of *proc*'s whole process group, then reap it.

        The bystanders are launched with ``start_new_session=True``, so each
        one leads its own process group and its pgid equals its pid. Passing
        the pid straight to killpg (instead of calling ``os.getpgid``) still
        reaches orphaned group members such as the ``sleep 300`` child, even
        if the bash leader itself can no longer be looked up.

        Reaping the leader keeps it from lingering as a zombie for the rest
        of the pytest session. If a wait times out, SIGKILL is re-sent once
        and the wait is retried; errors are swallowed on purpose because this
        is cleanup.
        """
        for _ in range(2):
            try:
                os.killpg(proc.pid, signal.SIGKILL)
            except OSError:  # ProcessLookupError / PermissionError included
                pass
            try:
                proc.wait(timeout=2)
                return
            except subprocess.TimeoutExpired:
                continue

    def test_stragglers_not_killed_in_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=False, _cleanup_stragglers
        should NOT kill unrelated processes matching the straggler pattern."""
        marker = _unique_marker()
        bystander = subprocess.Popen(
            _bash_sleep_cmd(marker, "bystander"),
            start_new_session=True,
        )
        bystander_pid = bystander.pid

        try:
            mp = ManagedProcess(
                command=_bash_sleep_cmd(marker, "main"),
                timeout=10,
                display_output=False,
                terminate_all_matching_process_names=False,
                straggler_commands=[marker],
                log_dir=str(tmp_path),
            )

            with mp:
                pass

            # A signal sent in error may not have taken effect the instant
            # __exit__ returns. Allow a short settle window so a wrongful kill
            # is actually observed instead of slipping past this check.
            time.sleep(0.5)
            assert _pid_alive(
                bystander_pid
            ), "Bystander was killed even though xdist-safe mode was on"
        finally:
            self._kill_and_reap(bystander)

    def test_stragglers_killed_when_not_xdist_mode(self, tmp_path):
        """With terminate_all_matching_process_names=True, _cleanup_stragglers
        SHOULD kill processes matching the straggler pattern."""
        marker = _unique_marker()
        victim_tag = f"{marker}_victim"
        launcher_tag = f"{marker}_launcher"

        bystander = subprocess.Popen(
            ["bash", "-c", f": {victim_tag}; sleep 300; : noexit"],
            start_new_session=True,
        )
        bystander_pid = bystander.pid

        try:
            mp = ManagedProcess(
                command=["bash", "-c", f": {launcher_tag}; sleep 1"],
                timeout=10,
                display_output=False,
                display_name=launcher_tag,
                terminate_all_matching_process_names=True,
                straggler_commands=[victim_tag],
                log_dir=str(tmp_path),
            )

            with mp:
                time.sleep(0.5)

            assert _wait_for_pid_death(
                bystander_pid, timeout=10
            ), "Bystander with matching straggler command should have been killed"
        finally:
            self._kill_and_reap(bystander)


# ---------------------------------------------------------------------------
# Scenario 5: Process already dead before __exit__
# ---------------------------------------------------------------------------
class TestAlreadyDeadProcess:
    def test_exit_handles_dead_process(self, tmp_path):
        """Teardown must tolerate a child that has already exited on its own.

        ``bash -c "exit 0"`` finishes almost at once. The half-second pause
        inside the context gives it time to exit before __exit__ runs, and the
        test passes as long as leaving the context raises nothing.
        """
        settings = dict(
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )
        with ManagedProcess(command=["bash", "-c", "exit 0"], **settings):
            time.sleep(0.5)


# ---------------------------------------------------------------------------
# Scenario 6: SIGTERM grace period — process that traps SIGTERM
# ---------------------------------------------------------------------------
class TestSigtermGracePeriod:
    def test_process_gets_sigterm_grace_before_sigkill(self, tmp_path):
        """A process that handles SIGTERM and takes a moment to exit should
        get the grace period, not be immediately SIGKILLed.

        Uses a Python child that writes a "ready" file after installing its
        SIGTERM handler, so we don't race against interpreter startup. The
        handler touches a marker file before exiting; the marker's presence
        after teardown proves SIGTERM was delivered and handled.

        The temp paths are embedded with ``!r`` so they become valid Python
        string literals even if they contain quotes or backslashes. The idle
        loop is a generator consumed by ``any()`` (every item is None), so it
        spins forever without accumulating a list.
        """
        marker_file = str(tmp_path / "got_sigterm")
        ready_file = str(tmp_path / "ready")
        script = (
            "import os, signal, time, pathlib; "
            f"marker = pathlib.Path({marker_file!r}); "
            f"ready = pathlib.Path({ready_file!r}); "
            "signal.signal(signal.SIGTERM, "
            "lambda *_: (marker.touch(), os._exit(0))); "
            "ready.touch(); "
            "any(time.sleep(0.1) for _ in iter(int, 1))"
        )
        mp = ManagedProcess(
            command=["python3", "-c", script],
            timeout=10,
            display_output=False,
            terminate_all_matching_process_names=False,
            log_dir=str(tmp_path),
        )

        with mp:
            assert mp.proc is not None
            deadline = time.monotonic() + 5.0
            while not os.path.exists(ready_file):
                assert time.monotonic() < deadline, "Child never became ready"
                time.sleep(0.05)

        assert os.path.exists(
            marker_file
        ), "Process was SIGKILLed before SIGTERM handler could run"