managed_process.py 14.7 KB
Newer Older
Neelay Shah's avatar
Neelay Shah committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import json
Neelay Shah's avatar
Neelay Shah committed
17
18
19
20
21
22
23
24
25
26
27
28
29
import logging
import os
import shutil
import socket
import subprocess
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional

import psutil
import requests


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def terminate_process(process, logger=logging.getLogger(), immediate_kill=False):
    """Signal a single psutil process to exit.

    Calls ``process.terminate()`` by default, or ``process.kill()`` when
    ``immediate_kill`` is true. A process that has already disappeared, or
    that we are not permitted to signal, is logged as a warning and skipped.
    """
    pid = process.pid
    try:
        name = process.name()
        logger.info("Terminating PID: %s name: %s", pid, name)
        if not immediate_kill:
            process.terminate()
            return
        logger.info("Sending Kill: %s %s", pid, process.name())
        process.kill()
    except psutil.NoSuchProcess:
        logger.warning("PID %s no longer exists", pid)
    except psutil.AccessDenied:
        logger.warning("Access denied for PID %s", pid)


def terminate_process_tree(
    pid, logger=logging.getLogger(), immediate_kill=False, timeout=10
):
    """Terminate ``pid`` and all of its descendants, escalating to kill.

    The descendant list is captured *before* anything is signalled. Once the
    parent exits, its children are re-parented, so a second
    ``parent.children()`` call would no longer report them. They would then
    escape the wait/kill escalation.

    Args:
        pid: Root of the process tree to terminate.
        logger: Logger used for progress and warning messages.
        immediate_kill: Kill right away instead of terminating first.
        timeout: Seconds to wait, in total, for the whole tree to exit before
            the survivors are killed.
    """
    try:
        parent = psutil.Process(pid)
        # Children first, then the parent (same order as before).
        procs = parent.children(recursive=True) + [parent]

        for proc in procs:
            terminate_process(proc, logger, immediate_kill)

        # wait_procs waits on every process against one shared deadline
        # instead of stacking a full ``timeout`` per process.
        _, alive = psutil.wait_procs(procs, timeout=timeout)
        for proc in alive:
            terminate_process(proc, logger, immediate_kill=True)

    except psutil.NoSuchProcess:
        # Process already terminated
        pass


Neelay Shah's avatar
Neelay Shah committed
69
70
71
72
73
74
@dataclass
class ManagedProcess:
    """Context manager that runs a command and waits for it to become healthy.

    ``__enter__``:
      * terminates leftover matching processes when ``terminate_existing``,
      * wipes ``data_dir`` if set,
      * starts ``command`` in a new session, piping its combined
        stdout/stderr through ``sed`` (which prefixes ``[COMMAND] ``) into
        ``<log_dir>/<command>.log.txt``; with ``display_output`` the output
        goes through ``tee`` so it is also echoed to this process's stdout,
      * sleeps ``delayed_start`` seconds, then waits until every
        ``health_check_ports`` port accepts a TCP connection on localhost
        and every ``health_check_urls`` URL answers HTTP 200.

    The health checks share one ``timeout`` budget. They raise
    ``RuntimeError`` when it runs out or when the main process exits early.

    ``__exit__`` terminates the process trees and wipes ``data_dir`` again.
    It then terminates any process named in ``stragglers`` or whose command
    line contains an entry of ``straggler_commands``.
    """

    command: List[str]
    # Child environment; None (or an empty dict) means a copy of os.environ.
    env: Optional[dict] = None
    health_check_ports: List[int] = field(default_factory=list)
    # Each entry is a URL string or a (url, predicate(response) -> bool) tuple.
    health_check_urls: List[Any] = field(default_factory=list)
    # Seconds to sleep after starting the process, before health checks.
    delayed_start: int = 0
    # Total seconds allowed for the port checks plus the URL checks.
    timeout: int = 300
    working_dir: Optional[str] = None
    display_output: bool = False
    # Directory removed before start and again after shutdown.
    data_dir: Optional[str] = None
    terminate_existing: bool = True
    stragglers: List[str] = field(default_factory=list)
    straggler_commands: List[str] = field(default_factory=list)
    # Resolved when the instance is created, not when the module is imported.
    log_dir: str = field(default_factory=os.getcwd)

    _logger = logging.getLogger()
    _command_name = None
    _log_path = None
    _tee_proc = None
    _sed_proc = None
    # Main subprocess.Popen. Declared here so that __exit__ works even when
    # __enter__ fails before the process has been started.
    proc = None

    def __enter__(self):
        try:
            self._logger = logging.getLogger(self.__class__.__name__)
            # Use the executable's basename. It is what psutil's name()
            # reports. A path here would escape log_dir in os.path.join and
            # break the sed substitution.
            self._command_name = os.path.basename(self.command[0])
            os.makedirs(self.log_dir, exist_ok=True)
            self._log_path = os.path.join(
                self.log_dir, f"{self._command_name}.log.txt"
            )

            if self.data_dir:
                self._remove_directory(self.data_dir)

            self._terminate_existing()
            self._start_process()
            time.sleep(self.delayed_start)

            elapsed = self._check_ports(self.timeout)
            self._check_urls(self.timeout - elapsed)

            return self

        except BaseException:
            # A ``with`` statement does not call __exit__ when __enter__
            # raises, so clean up here. BaseException so that Ctrl-C during a
            # long health check does not leak the server.
            self.__exit__(None, None, None)
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        for process in (self.proc, self._tee_proc, self._sed_proc):
            if not process:
                continue
            if process.stdout:
                process.stdout.close()
            if process.stdin:
                process.stdin.close()
            terminate_process_tree(process.pid, self._logger)
            process.wait()

        if self.data_dir:
            self._remove_directory(self.data_dir)

        if not (self.stragglers or self.straggler_commands):
            return

        protected = self._protected_pids()
        for ps_process in psutil.process_iter(["name", "cmdline"]):
            try:
                if ps_process.pid in protected:
                    continue
                if ps_process.name() in self.stragglers:
                    self._logger.info(
                        "Terminating Straggler %s %s", ps_process.name(), ps_process.pid
                    )
                    terminate_process_tree(ps_process.pid, self._logger)
                    continue
                if not self.straggler_commands:
                    continue
                joined_cmdline = " ".join(ps_process.cmdline())
                for cmdline in self.straggler_commands:
                    if cmdline in joined_cmdline:
                        self._logger.info(
                            "Terminating Straggler Cmdline %s %s %s",
                            ps_process.name(),
                            ps_process.pid,
                            cmdline,
                        )
                        terminate_process_tree(ps_process.pid, self._logger)
                        break
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                # Process may have terminated or become inaccessible during iteration
                pass

    @staticmethod
    def _protected_pids():
        """Return the PIDs of this process and all of its ancestors.

        Matching is by process name or command-line substring, which can
        easily hit the test runner itself (e.g. a ``python`` command).
        Terminating any of these PIDs' trees would kill this process.
        """
        current = psutil.Process()
        return {current.pid, *(p.pid for p in current.parents())}

    def _start_process(self):
        """Start the command and the sed (plus optional tee) log pipeline."""
        assert self._command_name
        assert self._log_path

        self._logger.info(
            "Running command: %s in %s",
            " ".join(self.command),
            self.working_dir or os.getcwd(),
        )

        self.proc = subprocess.Popen(
            self.command,
            env=self.env or os.environ.copy(),
            cwd=self.working_dir,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Isolate process group to prevent kill 0 from affecting parent
        )
        # Prefix every output line with "[COMMAND] "; -u keeps sed unbuffered.
        prefix_command = ["sed", "-u", f"s/^/[{self._command_name.upper()}] /"]

        if self.display_output:
            self._sed_proc = subprocess.Popen(
                prefix_command, stdin=self.proc.stdout, stdout=subprocess.PIPE
            )
            # tee writes the log file and inherits our stdout for display.
            self._tee_proc = subprocess.Popen(
                ["tee", self._log_path], stdin=self._sed_proc.stdout
            )
            # Drop our copy of the sed->tee read end (see below).
            self._sed_proc.stdout.close()
        else:
            with open(self._log_path, "w", encoding="utf-8") as log_file:
                self._sed_proc = subprocess.Popen(
                    prefix_command, stdin=self.proc.stdout, stdout=log_file
                )
            self._tee_proc = None

        # Only sed should hold the read end of the server's output pipe. If
        # the reader goes away, the server then gets SIGPIPE instead of
        # blocking on a full pipe.
        self.proc.stdout.close()

    def _remove_directory(self, path: str) -> None:
        """Remove a directory tree; a missing directory is not an error."""
        try:
            shutil.rmtree(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            self._logger.warning("Warning: Failed to remove directory %s: %s", path, e)

    def _log_tail_on_error(self, lines=20):
        """Log the last ``lines`` lines of the process log at ERROR level.

        Called when the process dies, so that the cause shows up in the
        output. Does nothing if there is no log or ``lines`` is not positive.
        Without that guard, ``log_lines[-0:]`` would dump the whole file.
        """
        if lines <= 0 or not self._log_path or not os.path.exists(self._log_path):
            return
        try:
            # The log holds raw child output; replace undecodable bytes rather
            # than losing the whole tail.
            with open(self._log_path, "r", encoding="utf-8", errors="replace") as f:
                log_lines = f.readlines()
        except OSError as e:
            self._logger.warning("Could not read log file: %s", e)
            return
        if not log_lines:
            return
        tail = log_lines[-lines:]
        self._logger.error("=== Last %d lines from %s ===", len(tail), self._log_path)
        for line in tail:
            self._logger.error(line.rstrip())
        self._logger.error("=== End of log tail ===")

    def _check_process_alive(self, context=""):
        """Raise RuntimeError (after logging the log tail) if ``proc`` exited."""
        if self.proc is None or self.proc.poll() is None:
            return
        returncode = self.proc.returncode
        suffix = f" {context}" if context else ""
        self._logger.error(
            "Main server process died with exit code %d%s", returncode, suffix
        )
        # Try to get last few lines from log for debugging
        self._log_tail_on_error()
        raise RuntimeError(f"Main server process exited with code {returncode}{suffix}")

    def _check_ports(self, timeout):
        """Wait for every health-check port in turn; return total seconds used."""
        elapsed = 0.0
        for port in self.health_check_ports:
            elapsed += self._check_port(port, timeout - elapsed)
        return elapsed

    def _check_port(self, port, timeout=30, sleep=0.1):
        """Poll until a TCP connect to localhost:``port`` succeeds.

        Returns the seconds spent. Raises RuntimeError on timeout or if the
        main process exits while waiting.
        """
        start_time = time.time()
        self._logger.info("Checking Port: %s", port)
        elapsed = 0.0
        while elapsed < timeout:
            # Check if the main process is still alive
            self._check_process_alive(f"while waiting for port {port}")

            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if s.connect_ex(("localhost", port)) == 0:
                    self._logger.info("SUCCESS: Check Port: %s", port)
                    return time.time() - start_time
            time.sleep(sleep)
            elapsed = time.time() - start_time
        self._logger.error("FAILED: Check Port: %s", port)
        raise RuntimeError("FAILED: Check Port: %s" % port)

    def _check_urls(self, timeout):
        """Wait for every health-check URL in turn; return total seconds used."""
        elapsed = 0.0
        for url in self.health_check_urls:
            elapsed += self._check_url(url, timeout - elapsed)
        return elapsed

    def _check_url(self, url, timeout=30, sleep=1, log_interval=10):
        """Poll ``url`` until it returns HTTP 200 and passes its predicate.

        ``url`` is either a URL string or a ``(url, predicate)`` tuple. For a
        tuple, ``predicate(response)`` must also return truthy. Progress is
        logged every ``log_interval`` seconds while failing. Returns the
        seconds spent. Raises RuntimeError on timeout or if the main process
        exits while waiting.
        """
        if isinstance(url, tuple):
            url, response_check = url[0], url[1]
        else:
            response_check = None
        start_time = time.time()
        self._logger.info("Checking URL %s", url)
        elapsed = 0.0
        attempt = 0
        last_log_time = 0.0

        while elapsed < timeout:
            self._check_process_alive("while waiting for health check")
            attempt += 1

            # The predicate stays inside the try: a predicate that calls
            # response.json() can raise a RequestException subclass, which
            # counts as a failed attempt.
            try:
                response = requests.get(url, timeout=timeout - elapsed)
                if response.status_code != 200:
                    failure_reason = f"status code {response.status_code}"
                elif response_check is not None and not response_check(response):
                    failure_reason = "custom check failed"
                else:
                    self._log_url_success(
                        url, attempt, time.time() - start_time, response
                    )
                    return time.time() - start_time
            except requests.RequestException as e:
                failure_reason = f"request exception: {e}"

            # Log progress every log_interval seconds for any failure
            if elapsed - last_log_time >= log_interval:
                self._logger.info(
                    "Still waiting for URL %s (%s) (attempt=%d, elapsed=%.1fs)",
                    url,
                    failure_reason,
                    attempt,
                    elapsed,
                )
                last_log_time = elapsed

            time.sleep(sleep)
            elapsed = time.time() - start_time

        self._logger.error(
            "FAILED: Check URL: %s (attempts=%d, elapsed=%.1fs)", url, attempt, elapsed
        )
        raise RuntimeError("FAILED: Check URL: %s" % url)

    def _log_url_success(self, url, attempt, elapsed, response):
        """Log a passed URL check; pretty-print a JSON body when possible."""
        try:
            # json.JSONDecodeError and requests' JSONDecodeError are ValueErrors.
            response_str = json.dumps(response.json(), indent=2)
        except ValueError:
            # Not JSON: show raw text (truncated if too long)
            response_text = response.text
            if len(response_text) > 500:
                response_text = response_text[:500] + "... (truncated)"
            self._logger.info(
                "SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse: %s",
                url,
                attempt,
                elapsed,
                response_text,
            )
        else:
            self._logger.info(
                "SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse:\n%s",
                url,
                attempt,
                elapsed,
                response_str,
            )

    def _terminate_existing(self):
        """Terminate leftover processes from earlier runs, if enabled.

        Matches processes named like the command or listed in
        ``stragglers``. Also matches processes whose joined command line
        contains any ``straggler_commands`` entry. This process and its
        ancestors are never matched.
        """
        if not self.terminate_existing:
            return
        protected = self._protected_pids()
        for proc in psutil.process_iter(["name", "cmdline"]):
            try:
                if proc.pid in protected:
                    continue
                if proc.name() == self._command_name or proc.name() in self.stragglers:
                    self._logger.info(
                        "Terminating Existing %s %s", proc.name(), proc.pid
                    )
                    terminate_process_tree(proc.pid, self._logger)
                    continue
                if not self.straggler_commands:
                    continue
                joined_cmdline = " ".join(proc.cmdline())
                for cmdline in self.straggler_commands:
                    if cmdline in joined_cmdline:
                        self._logger.info(
                            "Terminating Existing CmdLine %s %s %s",
                            proc.name(),
                            proc.pid,
                            proc.cmdline(),
                        )
                        terminate_process_tree(proc.pid, self._logger)
                        break
            except (
                psutil.NoSuchProcess,
                psutil.AccessDenied,
                psutil.ZombieProcess,
            ):
                # Process may have terminated or become inaccessible during iteration
                pass
Neelay Shah's avatar
Neelay Shah committed
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398


def main():
    """Manual smoke test: serve a model over HTTP and keep it up for a minute."""
    server = ManagedProcess(
        command=[
            "dynamo",
            "run",
            "in=http",
            "out=vllm",
            "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        ],
        display_output=True,
        terminate_existing=True,
        health_check_ports=[8080],
        health_check_urls=["http://localhost:8080/v1/models"],
        timeout=10,
    )
    with server:
        # Keep the server running so it can be exercised by hand.
        time.sleep(60)


if __name__ == "__main__":
    main()