engine_process.py 9.03 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
# SPDX-License-Identifier: Apache-2.0

import json
import logging
6
import os
7
import time
8
9
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
10
11
12

import requests

13
from tests.utils.constants import DefaultPort
14
from tests.utils.managed_process import ManagedProcess
15
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api
16
17
18

logger = logging.getLogger(__name__)

19

20
21
22
# Default frontend port taken from DefaultPort.
# Do NOT use this in tests! Use allocate_port() instead.
FRONTEND_PORT = DefaultPort.FRONTEND.value
23

24
25
26
27
28
29
30

class EngineResponseError(Exception):
    """Raised when an engine response is invalid or lacks the expected content."""


31
32
class EngineLogError(Exception):
    """Raised when validation of the engine logs fails."""
35
36


37
38
39
@dataclass
class EngineConfig:
    """Shared configuration describing a single engine test scenario.

    Exactly one of ``script_name`` or ``command`` has to be supplied; the
    constraint is enforced in ``__post_init__``.
    """

    # Required fields (no defaults).
    name: str
    directory: str
    marks: List[Any]
    request_payloads: List[BasePayload]
    model: str

    # How the engine is launched: a script name (plus optional args) or a
    # ready-made command list.
    script_name: Optional[str] = None
    command: Optional[List[str]] = None
    script_args: Optional[List[str]] = None
    frontend_port: int = DefaultPort.FRONTEND.value

    # Startup / health-check tuning.
    timeout: int = 600
    delayed_start: int = 0
    health_check_workers: bool = False

    # Per-instance mutable defaults, hence default_factory.
    env: Dict[str, str] = field(default_factory=dict)
    stragglers: list[str] = field(default_factory=list)

    def __post_init__(self):
        """Reject configs that set neither, or both, of script_name and command."""
        has_script = bool(self.script_name)
        has_command = bool(self.command)

        if has_script and has_command:
            raise ValueError("Cannot provide both script_name and command")
        if not (has_script or has_command):
            raise ValueError("Either script_name or command must be provided")

64

65
66
class EngineProcess(ManagedProcess):
    """Base class for LLM engine processes (vLLM, TRT-LLM, etc.)"""
67
68
69

    def check_response(
        self,
        payload: BasePayload,
        response: requests.Response,
    ) -> None:
        """
        Check if the response is valid and contains expected content.

        Args:
            payload: The original payload. Its ``process_response`` method is
                used to extract the content, and a truthy ``expected_log``
                attribute triggers log validation afterwards.
            response: The response object

        Raises:
            EngineResponseError: If the response is invalid or missing expected content
            EngineLogError: Propagated from ``validate_expected_logs`` when
                ``payload.expected_log`` is set and validation fails
        """

        if response.status_code != 200:
            logger.error(
                "Response returned non-200 status code: %d", response.status_code
            )

            error_msg = f"Response returned non-200 status code: {response.status_code}"
            # Best-effort diagnostics only: any failure while decoding the body
            # falls back to logging a truncated copy of the raw text.
            try:
                error_data = response.json()
                # Only index into the body when it really is a JSON object; a
                # JSON string/list would otherwise raise on ['error'] and the
                # parsed details would never be logged.
                if isinstance(error_data, dict) and "error" in error_data:
                    error_msg += f"\nError details: {error_data['error']}"
                logger.error(
                    "Response error details: %s", json.dumps(error_data, indent=2)
                )
            except Exception:
                logger.error("Response text: %s", response.text[:500])

            raise EngineResponseError(error_msg)

        try:
            content = payload.process_response(response)

            # Keep the log line short: show at most 200 characters of string content.
            if isinstance(content, str) and len(content) > 200:
                preview = content[:200] + "..."
            else:
                preview = content
            logger.info("Extracted content: \n%s", preview)
        except AssertionError as e:
            # Chain explicitly so the original assertion stays attached as the cause.
            raise EngineResponseError(str(e)) from e
        except Exception as e:
            raise EngineResponseError(f"Failed to handle response: {e}") from e

        # Optionally validate expected log patterns after response handling
        if payload.expected_log:
            time.sleep(
                0.5
            )  # The kv event sometimes needs extra time to arrive and be reflected in the log.
            self.validate_expected_logs(payload.expected_log)

    def validate_expected_logs(self, patterns: Any) -> None:
        """Validate that all regex patterns are present in the current logs.

        Reads the full log via ``self.read_logs()`` and searches for each
        provided regex pattern.

        Args:
            patterns: A single regex (``str`` or compiled pattern) or an
                iterable of them. One-shot iterables such as generators are
                accepted; they are materialised before use.

        Raises:
            EngineLogError: If the log is empty/unavailable or any pattern is
                not found in it.
        """
        import re  # local import to keep module load minimal

        content = self.read_logs() or ""
        if not content:
            raise EngineLogError(
                f"Log file not available or empty at path: {self.log_path}"
            )

        # A bare string would otherwise be iterated character by character,
        # and a one-shot iterator would be exhausted before the search loop
        # (silently reporting success), so normalise to a concrete list once.
        if isinstance(patterns, (str, re.Pattern)):
            pattern_list = [patterns]
        else:
            pattern_list = list(patterns)

        # Compile everything up front so an invalid regex fails before searching.
        compiled = [re.compile(p) for p in pattern_list]
        missing = [
            pattern
            for pattern, rx in zip(pattern_list, compiled)
            if not rx.search(content)
        ]

        if missing:
            # Only the tail of the log is included to keep the error readable.
            sample = content[-1000:] if len(content) > 1000 else content
            raise EngineLogError(
                f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
            )
        logger.info("SUCCESS: All expected log patterns: %s found", patterns)

    @classmethod
    def from_config(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Build an EngineProcess from an EngineConfig (launch script or raw command).

        Args:
            config: Scenario configuration; must be an EngineConfig instance.
            request: Object whose ``node.name`` is used as the log directory.
            extra_env: Optional variables applied on top of the process
                environment and ``config.env``.

        Raises:
            ValueError: If the config has neither a script name nor a command.
        """
        assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"

        if not (config.script_name or config.command):
            raise ValueError("Either script_name or command must be provided in config")
        command = (
            cls._build_script_command(config)
            if config.script_name
            else config.command.copy()
        )

        # Layer the environment: os.environ < config.env < extra_env.
        env = os.environ.copy()
        for overrides in (getattr(config, "env", None), extra_env):
            if overrides:
                env.update(overrides)

        frontend_base = f"http://localhost:{config.frontend_port}"
        frontend_checks = [
            (f"{frontend_base}/v1/models", check_models_api),
            (f"{frontend_base}/health", check_health_generate),
        ]

        # Disagg-same-gpu deployments: probe every worker's system port so we
        # wait for ALL workers, not just the first one that registers with the
        # frontend.  Worker liveness checks are placed FIRST so the frontend
        # has time to discover newly-registered workers before its own
        # endpoint checks run.
        #
        # NOTE: the dynamic port fixtures inject DYN_SYSTEM_PORT* env vars for
        # ALL tests, so this is gated on health_check_workers (only set by
        # same-gpu disagg configs) to avoid probing ports that don't serve
        # /health in regular multi-GPU tests.
        worker_checks: list[tuple] = []
        if config.health_check_workers:
            worker_checks = [
                (f"http://localhost:{port}/health", None)
                for var_name, port in sorted(env.items())
                if var_name.startswith("DYN_SYSTEM_PORT") and port.isdigit()
            ]

        # When worker checks are in place the configured start delay is dropped.
        delayed = 0 if worker_checks else config.delayed_start
        health_urls = worker_checks + frontend_checks

        return cls(
            command=command,
            env=env,
            timeout=config.timeout,
            display_output=True,
            working_dir=config.directory,
            health_check_ports=[],
            health_check_urls=health_urls,
            delayed_start=delayed,
            # Must stay False: command[0] is "bash", so True would kill every
            # bash process system-wide.  Stale cleanup relies on stragglers list
            # and process-group termination in __exit__ instead.
            terminate_all_matching_process_names=False,
            stragglers=config.stragglers,
            log_dir=request.node.name,
        )
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265

    @classmethod
    def _build_script_command(cls, config: EngineConfig) -> List[str]:
        """Return the ``bash <directory>/launch/<script_name> [args...]`` command.

        Raises:
            AssertionError: If ``config.script_name`` is not set.
            FileNotFoundError: If the launch script path does not exist.
        """
        assert (
            config.script_name
        ), "Must provide script_name to run fn _build_script_command"

        script_path = os.path.join(config.directory, "launch", config.script_name)
        if not os.path.exists(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")

        # script_args is optional; None or an empty list adds nothing.
        extra_args = config.script_args or []
        return ["bash", script_path, *extra_args]

    @classmethod
    def from_script(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Compatibility alias that forwards all arguments to from_config().

        Deprecated: Use from_config() instead.
        """
        process = cls.from_config(config, request, extra_env)
        return process

    @classmethod
    def from_command(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Compatibility alias that forwards all arguments to from_config().

        Deprecated: Use from_config() instead.
        """
        process = cls.from_config(config, request, extra_env)
        return process