# managed_process.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import shutil
import signal
import socket
import subprocess
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, List, Optional

import psutil
import requests


def terminate_process(process, logger=logging.getLogger(), immediate_kill=False):
    """Ask a single psutil process to stop.

    Calls ``process.terminate()`` by default, or ``process.kill()`` when
    ``immediate_kill`` is true. A process we may not signal, or one that has
    already disappeared, is reported as a warning instead of raising.

    Args:
        process: ``psutil.Process``-like object to signal.
        logger: Logger used for progress and warning messages.
        immediate_kill: Kill instead of terminate.
    """
    pid = process.pid
    try:
        logger.info("Terminating PID: %s name: %s", pid, process.name())
        if not immediate_kill:
            process.terminate()
            return
        logger.info("Sending Kill: %s %s", pid, process.name())
        process.kill()
    except psutil.AccessDenied:
        logger.warning("Access denied for PID %s", pid)
    except psutil.NoSuchProcess:
        logger.warning("PID %s no longer exists", pid)


def terminate_process_tree(
    pid, logger=logging.getLogger(), immediate_kill=False, timeout=10
):
    """Terminate process ``pid`` and all of its descendants.

    Every process in the tree is first asked to stop (``terminate()``, or
    ``kill()`` when ``immediate_kill`` is true). Any process still alive
    ``timeout`` seconds later is killed. A ``pid`` that no longer exists is
    silently ignored.

    Args:
        pid: PID of the root of the tree.
        logger: Logger used for progress and warning messages.
        immediate_kill: Kill instead of terminate on the first pass.
        timeout: Seconds to wait, shared by the whole tree, before killing
            the survivors.
    """
    try:
        parent = psutil.Process(pid)
        # Snapshot the descendants *before* signalling anything. Once the
        # parent exits, its children are reparented, so re-querying
        # parent.children() afterwards either misses them or raises
        # NoSuchProcess, which would skip the kill escalation entirely.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        # Process already terminated
        return

    for child in children:
        terminate_process(child, logger, immediate_kill)
    terminate_process(parent, logger, immediate_kill)

    # Wait for the whole tree at once (one shared timeout rather than one
    # per process), then escalate on whatever is still running.
    _, alive = psutil.wait_procs(children + [parent], timeout=timeout)
    for proc in alive:
        terminate_process(proc, logger, immediate_kill=True)


@dataclass
class ManagedProcess:
    """Run a command for the lifetime of a ``with`` block.

    ``__enter__`` does the following, in order:

    * wipes ``data_dir`` when it is set;
    * terminates matching leftover processes when ``terminate_existing`` is
      set;
    * launches ``command`` in a new session;
    * pipes the command's combined stdout/stderr through ``sed``, which
      prefixes each line with ``[COMMAND] ``, into
      ``<log_dir>/<command>.log.txt`` (also echoed to this process's stdout
      via ``tee`` when ``display_output`` is set);
    * waits, within ``timeout`` seconds overall, for every health-check port
      to accept TCP connections and every health-check URL to return
      HTTP 200.

    ``__exit__`` terminates the child's process group, the pipeline
    processes and any configured stragglers, and wipes ``data_dir`` again.
    """

    # Command line to run; the basename of command[0] names the log file
    # and the log-line prefix.
    command: List[str]
    # Environment for the child; None (or empty) means a copy of os.environ.
    env: Optional[dict] = None
    # localhost TCP ports that must accept connections before __enter__ returns.
    health_check_ports: List[int] = field(default_factory=list)
    # URLs, or (url, check) tuples where check(response) must be truthy,
    # that must answer HTTP 200 before __enter__ returns.
    health_check_urls: List[Any] = field(default_factory=list)
    # Seconds to sleep after launching, before the health checks start.
    delayed_start: int = 0
    # Overall budget in seconds shared by the port and URL checks.
    timeout: int = 300
    # Working directory for the child; None means the current directory.
    working_dir: Optional[str] = None
    # Also echo the prefixed output to this process's stdout via tee.
    display_output: bool = False
    # Directory removed before start and after exit, if set.
    data_dir: Optional[str] = None
    # On start, kill processes named like the command or like a straggler.
    terminate_existing: bool = True
    # Process names terminated on exit (and on start, see terminate_existing).
    stragglers: List[str] = field(default_factory=list)
    # Command-line substrings identifying processes to terminate likewise.
    straggler_commands: List[str] = field(default_factory=list)
    # default_factory so the default is the working directory when the
    # instance is created, not when this module was imported.
    log_dir: str = field(default_factory=os.getcwd)

    # Ensure attributes exist even if startup fails early
    proc: Optional[subprocess.Popen] = None
    _pgid: Optional[int] = None

    _logger = logging.getLogger()
    _command_name = None
    _log_path = None
    _tee_proc = None
    _sed_proc = None

    def __enter__(self):
        try:
            self._logger = logging.getLogger(self.__class__.__name__)
            # Use only the basename. A path-like command[0] (e.g.
            # "./bin/server") would otherwise put the log outside log_dir,
            # break the sed "s/.../" expression with its "/", and be unlikely
            # to match psutil's Process.name() in _terminate_existing.
            self._command_name = os.path.basename(self.command[0]) or self.command[0]
            os.makedirs(self.log_dir, exist_ok=True)
            log_name = f"{self._command_name}.log.txt"
            self._log_path = os.path.join(self.log_dir, log_name)

            if self.data_dir:
                self._remove_directory(self.data_dir)

            self._terminate_existing()
            self._start_process()
            time.sleep(self.delayed_start)

            # Ports and URLs share one overall timeout budget.
            elapsed = self._check_ports(self.timeout)
            self._check_urls(self.timeout - elapsed)

            return self

        except Exception:
            # __exit__ is not invoked when __enter__ raises, so clean up here.
            try:
                self.__exit__(None, None, None)
            except Exception as cleanup_err:
                self._logger.warning(
                    "Error during cleanup in __enter__: %s", cleanup_err
                )
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Kill the whole session first so reparented grandchildren die too.
        self._terminate_process_group()

        for process in (self.proc, self._tee_proc, self._sed_proc):
            if not process:
                continue
            try:
                if process.stdout:
                    process.stdout.close()
                if process.stdin:
                    process.stdin.close()
                # Walk the tree only while the process is still running: once
                # reaped (poll() returned a code) it has no children left to
                # find and its PID may already belong to an unrelated process.
                if process.poll() is None:
                    terminate_process_tree(process.pid, self._logger)
                process.wait()
            except Exception as e:
                self._logger.warning("Error terminating process: %s", e)

        if self.data_dir:
            self._remove_directory(self.data_dir)

        self._terminate_stragglers()

    @staticmethod
    def _protected_pids():
        """Return the PIDs of this interpreter and its ancestors.

        The name/cmdline sweeps in _terminate_existing and
        _terminate_stragglers must never match these: with e.g.
        ``command=["python", ...]`` they would otherwise terminate the caller
        itself together with its process tree.
        """
        try:
            me = psutil.Process()
            return {me.pid, *(ancestor.pid for ancestor in me.parents())}
        except psutil.Error:
            return {os.getpid()}

    def _terminate_stragglers(self):
        """Terminate processes listed in ``stragglers`` / ``straggler_commands``.

        A process matches if its name is in ``stragglers`` or if its
        space-joined command line contains any ``straggler_commands`` entry.
        Processes that vanish or deny access mid-scan are skipped.
        """
        if not self.stragglers and not self.straggler_commands:
            # Nothing to match; skip scanning every process on the system.
            return
        protected = self._protected_pids()
        for ps_process in psutil.process_iter(["name", "cmdline"]):
            if ps_process.pid in protected:
                continue
            try:
                if ps_process.name() in self.stragglers:
                    self._logger.info(
                        "Terminating Straggler %s %s", ps_process.name(), ps_process.pid
                    )

                    terminate_process_tree(ps_process.pid, self._logger)
                for cmdline in self.straggler_commands:
                    if cmdline in " ".join(ps_process.cmdline()):
                        self._logger.info(
                            "Terminating Straggler Cmdline %s %s %s",
                            ps_process.name(),
                            ps_process.pid,
                            cmdline,
                        )
                        terminate_process_tree(ps_process.pid, self._logger)
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                # Process may have terminated or become inaccessible during iteration
                pass

    def _launch_command(self):
        """Start ``command`` in a new session, stdout+stderr on one pipe."""
        self.proc = subprocess.Popen(
            self.command,
            env=self.env or os.environ.copy(),
            cwd=self.working_dir,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Isolate process group to prevent kill 0 from affecting parent
        )
        # Capture the child's process group id for robust cleanup even if parent shell exits
        try:
            self._pgid = os.getpgid(self.proc.pid)
        except Exception as e:
            self._logger.warning("Could not get process group id: %s", e)
            self._pgid = None

    def _start_process(self):
        """Launch ``command`` and attach the sed (and optional tee) log pipeline."""
        assert self._command_name
        assert self._log_path

        self._logger.info(
            "Running command: %s in %s",
            " ".join(self.command),
            self.working_dir or os.getcwd(),
        )

        # Prefix every output line with "[COMMAND] ".
        sed_command = ["sed", "-u", f"s/^/[{self._command_name.upper()}] /"]

        if self.display_output:
            self._launch_command()
            self._sed_proc = subprocess.Popen(
                sed_command,
                stdin=self.proc.stdout,
                stdout=subprocess.PIPE,
            )
            self._tee_proc = subprocess.Popen(
                ["tee", self._log_path], stdin=self._sed_proc.stdout
            )
            # tee now holds the read end of the sed -> tee pipe; drop ours so
            # sed sees SIGPIPE/EPIPE if tee exits.
            self._sed_proc.stdout.close()
        else:
            # The log is opened before launching, so an unwritable log path
            # fails before the command is started.
            with open(self._log_path, "w", encoding="utf-8") as f:
                self._launch_command()
                self._sed_proc = subprocess.Popen(
                    sed_command,
                    stdin=self.proc.stdout,
                    stdout=f,
                )
            self._tee_proc = None

        # sed now holds the read end of the command -> sed pipe. Closing our
        # copy means that if sed exits, the command gets SIGPIPE/EPIPE rather
        # than blocking forever once the unread pipe buffer fills up.
        self.proc.stdout.close()

    def _terminate_process_group(self, timeout: float = 5.0):
        """Terminate the entire process group/session started for the child.

        This catches cases where the launcher shell exits and its children are reparented,
        leaving no parent PID to traverse, but they remain in the same process group.
        SIGTERM is sent first; SIGKILL follows if the group is still non-empty
        after ``timeout`` seconds.
        """
        if self._pgid is None:
            return
        try:
            self._logger.info("Terminating process group: %s", self._pgid)
            os.killpg(self._pgid, signal.SIGTERM)
        except ProcessLookupError:
            return
        except Exception as e:
            self._logger.warning(
                "Error sending SIGTERM to process group %s: %s", self._pgid, e
            )
            return

        # Give processes up to `timeout` seconds to exit gracefully, returning
        # as soon as the group is empty instead of always sleeping the full time.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if not self._process_group_alive():
                return
            time.sleep(0.1)

        # Force kill if anything remains
        try:
            os.killpg(self._pgid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        except Exception as e:
            self._logger.warning(
                "Error sending SIGKILL to process group %s: %s", self._pgid, e
            )

    def _process_group_alive(self) -> bool:
        """Return True while any process remains in the child's process group."""
        if self.proc is not None:
            # Reap the group leader if it has exited; its unreaped zombie would
            # otherwise keep the group non-empty until the timeout.
            self.proc.poll()
        try:
            os.killpg(self._pgid, 0)  # signal 0: existence check only
        except ProcessLookupError:
            return False
        except OSError:
            # Cannot tell; assume alive so the caller falls back to SIGKILL.
            return True
        return True

    def _remove_directory(self, path: str) -> None:
        """Best-effort recursive removal of ``path``; failures are only logged."""
        try:
            # ignore_errors=True keeps removing as much as possible instead of
            # stopping at the first failure, but it also hides those failures,
            # so check afterwards whether anything was left behind.
            shutil.rmtree(path, ignore_errors=True)
        except (OSError, IOError) as e:
            self._logger.warning("Warning: Failed to remove directory %s: %s", path, e)
            return
        if os.path.exists(path):
            self._logger.warning(
                "Warning: Failed to remove directory %s: %s", path, "not fully removed"
            )

    def _log_tail_on_error(self, lines=20):
        """Log the last ``lines`` lines of the log file when the process dies."""
        if not (self._log_path and os.path.exists(self._log_path)):
            return
        try:
            # Stream through a bounded deque so a huge log is not loaded into
            # memory just to show its tail; undecodable bytes are replaced
            # instead of aborting the dump.
            with open(self._log_path, "r", encoding="utf-8", errors="replace") as f:
                tail = deque(f, maxlen=lines)
        except Exception as e:
            self._logger.warning("Could not read log file: %s", e)
            return
        if tail:
            self._logger.error(
                "=== Last %d lines from %s ===",
                len(tail),
                self._log_path,
            )
            for line in tail:
                self._logger.error("%s", line.rstrip())
            self._logger.error("=== End of log tail ===")

    def _check_process_alive(self, context=""):
        """Check if the main process is still alive. Raises RuntimeError if dead."""
        if self.proc and self.proc.poll() is not None:
            returncode = self.proc.returncode
            self._logger.error(
                "Main server process died with exit code %d%s",
                returncode,
                f" {context}" if context else "",
            )
            # Try to get last few lines from log for debugging
            self._log_tail_on_error()
            raise RuntimeError(
                f"Main server process exited with code {returncode}{f' {context}' if context else ''}"
            )

    def _check_ports(self, timeout):
        """Wait for every health-check port in turn; return total seconds spent."""
        elapsed = 0.0
        for port in self.health_check_ports:
            elapsed += self._check_port(port, timeout - elapsed)
        return elapsed

    def _check_port(self, port, timeout=30, sleep=0.1):
        """Wait until ``port`` on localhost accepts TCP connections.

        Returns the seconds spent waiting. Raises RuntimeError if the main
        process dies or ``timeout`` seconds pass first.
        """
        start_time = time.monotonic()
        self._logger.info("Checking Port: %s", port)
        elapsed = 0.0
        while elapsed < timeout:
            # Check if the main process is still alive
            self._check_process_alive(f"while waiting for port {port}")

            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if s.connect_ex(("localhost", port)) == 0:
                    self._logger.info("SUCCESS: Check Port: %s", port)
                    return time.monotonic() - start_time
            time.sleep(sleep)
            elapsed = time.monotonic() - start_time
        self._logger.error("FAILED: Check Port: %s", port)
        raise RuntimeError("FAILED: Check Port: %s" % port)

    def _check_urls(self, timeout):
        """Wait for every health-check URL in turn; return total seconds spent."""
        elapsed = 0.0
        for url in self.health_check_urls:
            elapsed += self._check_url(url, timeout - elapsed)
        return elapsed

    def _log_url_success(self, url, attempt, elapsed, response):
        """Log a successful URL check together with the response body."""
        # Try to format JSON response nicely, otherwise show raw text
        try:
            response_str = json.dumps(response.json(), indent=2)
        except Exception:
            # If not JSON or any error, show raw text (truncated if too long)
            response_text = response.text
            if len(response_text) > 500:
                response_text = response_text[:500] + "... (truncated)"
            self._logger.info(
                "SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse: %s",
                url,
                attempt,
                elapsed,
                response_text,
            )
        else:
            self._logger.info(
                "SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse:\n%s",
                url,
                attempt,
                elapsed,
                response_str,
            )

    def _check_url(self, url, timeout=30, sleep=1, log_interval=10):
        """Poll ``url`` until it answers HTTP 200 and passes its optional check.

        ``url`` is a URL or a ``(url, response_check)`` tuple; when given,
        ``response_check(response)`` must be truthy for success. Returns the
        seconds spent waiting. Raises RuntimeError if the main process dies
        or ``timeout`` seconds pass first.
        """
        if isinstance(url, tuple):
            url, response_check = url[0], url[1]
        else:
            response_check = None
        start_time = time.monotonic()
        self._logger.info("Checking URL %s", url)
        elapsed = 0.0
        attempt = 0
        last_log_time = 0.0

        while elapsed < timeout:
            self._check_process_alive("while waiting for health check")

            attempt += 1
            failure_reason = None

            try:
                response = requests.get(url, timeout=timeout - elapsed)
                if response.status_code != 200:
                    failure_reason = f"status code {response.status_code}"
                elif response_check is not None and not response_check(response):
                    failure_reason = "custom check failed"
                else:
                    self._log_url_success(url, attempt, elapsed, response)
                    return time.monotonic() - start_time
            except requests.RequestException as e:
                failure_reason = f"request exception: {e}"

            # Every path that reaches here failed; log progress every
            # log_interval seconds.
            if elapsed - last_log_time >= log_interval:
                self._logger.info(
                    "Still waiting for URL %s (%s) (attempt=%d, elapsed=%.1fs)",
                    url,
                    failure_reason,
                    attempt,
                    elapsed,
                )
                last_log_time = elapsed

            time.sleep(sleep)
            elapsed = time.monotonic() - start_time

        self._logger.error(
            "FAILED: Check URL: %s (attempts=%d, elapsed=%.1fs)", url, attempt, elapsed
        )
        raise RuntimeError("FAILED: Check URL: %s" % url)

    def _terminate_existing(self):
        """Kill leftovers matching the command name, stragglers or straggler_commands."""
        if not self.terminate_existing:
            return
        protected = self._protected_pids()
        for proc in psutil.process_iter(["name", "cmdline"]):
            if proc.pid in protected:
                continue
            try:
                if proc.name() == self._command_name or proc.name() in self.stragglers:
                    self._logger.info(
                        "Terminating Existing %s %s", proc.name(), proc.pid
                    )

                    terminate_process_tree(proc.pid, self._logger)
                for cmdline in self.straggler_commands:
                    if cmdline in " ".join(proc.cmdline()):
                        self._logger.info(
                            "Terminating Existing CmdLine %s %s %s",
                            proc.name(),
                            proc.pid,
                            proc.cmdline(),
                        )
                        terminate_process_tree(proc.pid, self._logger)
            except (
                psutil.NoSuchProcess,
                psutil.AccessDenied,
                psutil.ZombieProcess,
            ):
                # Process may have terminated or become inaccessible during iteration
                pass


def main():
    """Run the dynamo HTTP -> vLLM server under ManagedProcess for 60 seconds."""
    dynamo_cmd = [
        "dynamo",
        "run",
        "in=http",
        "out=vllm",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    ]
    server = ManagedProcess(
        command=dynamo_cmd,
        display_output=True,
        terminate_existing=True,
        health_check_ports=[8080],
        health_check_urls=["http://localhost:8080/v1/models"],
        timeout=10,
    )
    with server:
        # Hold the context open for 60 s before teardown.
        time.sleep(60)


if __name__ == "__main__":
    main()