# engine_process.py
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
# SPDX-License-Identifier: Apache-2.0

import json
import logging
6
import os
7
import time
8
9
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
10
11
12

import requests

13
from tests.utils.constants import DefaultPort
14
from tests.utils.managed_process import ManagedProcess
15
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api
16
17
18

logger = logging.getLogger(__name__)

# Default frontend port. Do NOT use this in tests! Use allocate_port() instead.
FRONTEND_PORT = DefaultPort.FRONTEND.value
22

23
24
25
26
27
28
29

class EngineResponseError(Exception):
    """Raised when an engine response is unsuccessful or fails payload validation."""


30
31
class EngineLogError(Exception):
    """Raised when engine log validation fails.

    Covers both an unavailable/empty log and expected patterns that are not
    found in the log.
    """
34
35


36
37
38
@dataclass
class EngineConfig:
    """Base configuration for engine test scenarios.

    Exactly one of ``script_name`` or ``command`` must be provided; this is
    enforced in ``__post_init__``.
    """

    # Scenario identifier.
    name: str
    # Working directory for the engine process; launch scripts are resolved
    # under ``<directory>/launch/``.
    directory: str
    # NOTE(review): presumably pytest marks applied to this scenario — confirm.
    marks: List[Any]
    # Request payloads for this scenario (presumably sent to the frontend —
    # confirm against callers).
    request_payloads: List[BasePayload]
    # Model for this scenario.
    model: str

    # Launch script file name, resolved under ``<directory>/launch/`` and run
    # with bash. Mutually exclusive with ``command``.
    script_name: Optional[str] = None
    # Full command to run directly, as an argv list. Mutually exclusive with
    # ``script_name``.
    command: Optional[List[str]] = None
    # Extra arguments appended after the launch script path.
    script_args: Optional[List[str]] = None
    # Frontend port polled by the health checks.
    frontend_port: int = DefaultPort.FRONTEND.value

    # Forwarded to ManagedProcess as ``timeout`` (presumably seconds — confirm).
    timeout: int = 600
    # Forwarded to ManagedProcess as ``delayed_start``.
    delayed_start: int = 0
    # Environment overrides merged on top of ``os.environ``.
    env: Dict[str, str] = field(default_factory=dict)
    # Forwarded to ManagedProcess as ``stragglers``; semantics defined there.
    stragglers: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Validate that either script_name or command is provided, but not both."""
        if not self.script_name and not self.command:
            raise ValueError("Either script_name or command must be provided")
        if self.script_name and self.command:
            raise ValueError("Cannot provide both script_name and command")

62

63
64
class EngineProcess(ManagedProcess):
    """Base class for LLM engine processes (vLLM, TRT-LLM, etc.)

    Adds helpers on top of ManagedProcess for validating HTTP responses and
    engine log output, plus factories that build a process from an
    EngineConfig.
    """

    def check_response(
        self,
        payload: BasePayload,
        response: requests.Response,
    ) -> None:
        """
        Check that a response succeeded and that the payload accepts its content.

        Args:
            payload: The original payload. Its ``process_response`` method
                extracts and validates the content. If its ``expected_log``
                attribute is truthy, those regex patterns must also appear in
                the engine log.
            response: The HTTP response returned for ``payload``.

        Raises:
            EngineResponseError: If the status code is not 200, or if the
                payload fails to process the response.
            EngineLogError: If ``payload.expected_log`` patterns are missing
                from the engine log.
        """
        if response.status_code != 200:
            logger.error(
                "Response returned non-200 status code: %d", response.status_code
            )

            error_msg = f"Response returned non-200 status code: {response.status_code}"
            # Best-effort enrichment only. A malformed or non-JSON error body
            # must not mask the status-code failure, so any problem here falls
            # back to logging the raw text.
            try:
                error_data = response.json()
                if isinstance(error_data, dict) and "error" in error_data:
                    error_msg += f"\nError details: {error_data['error']}"
                logger.error(
                    "Response error details: %s", json.dumps(error_data, indent=2)
                )
            except Exception:
                logger.error("Response text: %s", response.text[:500])

            raise EngineResponseError(error_msg)

        try:
            content = payload.process_response(response)
        except AssertionError as e:
            # process_response may signal validation failures via assert;
            # surface its message unchanged.
            raise EngineResponseError(str(e)) from e
        except Exception as e:
            raise EngineResponseError(f"Failed to handle response: {e}") from e
        else:
            logger.info(
                "Extracted content: \n%s",
                content[:200] + "..."
                if isinstance(content, str) and len(content) > 200
                else content,
            )

        # Optionally validate expected log patterns after response handling
        if payload.expected_log:
            # The kv event sometimes needs extra time to arrive and be
            # reflected in the log.
            time.sleep(0.5)
            self.validate_expected_logs(payload.expected_log)

    def validate_expected_logs(self, patterns: Any) -> None:
        """Validate that every regex pattern matches somewhere in the current logs.

        Reads the full log via ManagedProcess.read_logs and searches for each
        provided regex pattern with ``re.search`` semantics.

        Args:
            patterns: A single regex string, or an iterable of regex strings.
                A bare string is treated as one pattern rather than being
                iterated character by character.

        Raises:
            EngineLogError: If the log is unavailable or empty, or if any
                pattern is not found.
            re.error: If a pattern is not a valid regular expression.
        """
        import re  # local import to keep module load minimal

        # Materialize once. A generator would otherwise be exhausted by the
        # compile step and then silently yield no "missing" patterns.
        if isinstance(patterns, str):
            pattern_list = [patterns]
        else:
            pattern_list = list(patterns)

        content = self.read_logs() or ""
        if not content:
            raise EngineLogError(
                f"Log file not available or empty at path: {self.log_path}"
            )

        # Compile all patterns up front so an invalid one fails before any
        # searching is done.
        compiled = [re.compile(p) for p in pattern_list]
        missing = [
            pattern
            for pattern, rx in zip(pattern_list, compiled)
            if not rx.search(content)
        ]

        if missing:
            # Include only the tail of the log to keep the error readable.
            sample = content[-1000:]
            raise EngineLogError(
                f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
            )
        logger.info("SUCCESS: All expected log patterns: %s found", pattern_list)

    @classmethod
    def from_config(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess from configuration (script or command).

        Args:
            config: Scenario configuration. Exactly one of ``script_name`` or
                ``command`` is expected to be set.
            request: Test request object. ``request.node.name`` is used as the
                log directory name (presumably a pytest ``FixtureRequest``).
            extra_env: Extra environment variables applied last, on top of
                ``os.environ`` and ``config.env``.

        Returns:
            An EngineProcess whose health checks poll ``/v1/models`` and
            ``/health`` on ``localhost:config.frontend_port``.

        Raises:
            ValueError: If neither ``script_name`` nor ``command`` is set.
            FileNotFoundError: If the configured launch script does not exist.
        """
        assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"

        if config.script_name:
            command = cls._build_script_command(config)
        elif config.command:
            # Copy so the process never shares the caller's list object.
            command = config.command.copy()
        else:
            raise ValueError("Either script_name or command must be provided in config")

        # Precedence (lowest to highest): os.environ < config.env < extra_env.
        env = os.environ.copy()
        if config.env:
            env.update(config.env)
        if extra_env:
            env.update(extra_env)

        base_url = f"http://localhost:{config.frontend_port}"
        return cls(
            command=command,
            env=env,
            timeout=config.timeout,
            display_output=True,
            working_dir=config.directory,
            health_check_ports=[],
            health_check_urls=[
                (f"{base_url}/v1/models", check_models_api),
                (f"{base_url}/health", check_health_generate),
            ],
            delayed_start=config.delayed_start,
            terminate_existing=False,
            stragglers=config.stragglers,
            log_dir=request.node.name,
        )

    @classmethod
    def _build_script_command(cls, config: EngineConfig) -> List[str]:
        """Build a ``bash`` command line that runs ``config.script_name``.

        The script is resolved as ``<config.directory>/launch/<script_name>``,
        and any ``config.script_args`` are appended after it.

        Raises:
            FileNotFoundError: If the resolved script path does not exist.
        """
        assert (
            config.script_name
        ), "Must provide script_name to run fn _build_script_command"
        script_path = os.path.join(config.directory, "launch", config.script_name)

        if not os.path.exists(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")

        command: List[str] = ["bash", script_path]
        if config.script_args:
            command.extend(config.script_args)
        return command

    @classmethod
    def from_script(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess configured to run a launch script.

        Deprecated: Use from_config() instead (this simply delegates to it).
        """
        return cls.from_config(config, request, extra_env)

    @classmethod
    def from_command(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess configured to run a direct command.

        Deprecated: Use from_config() instead (this simply delegates to it).
        """
        return cls.from_config(config, request, extra_env)