# engine_process.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
6
7
8
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
9
10
11
12

import requests

from tests.utils.managed_process import ManagedProcess
13
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api
14
15
16

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

17
18
# Port number constant for the frontend.
# NOTE(review): nothing in the visible code references this name, and
# EngineConfig.models_port hard-codes the same 8000 default -- confirm whether
# the two are meant to share this constant.
FRONTEND_PORT = 8000

19
20
21
22
23
24
25

class EngineResponseError(Exception):
    """Raised when an engine HTTP response is invalid or fails validation."""


26
27
class EngineLogError(Exception):
    """Raised when engine log validation fails.

    Used when the log is unavailable/empty or when expected log patterns
    are not found in it.
    """
30
31


32
33
34
@dataclass
class EngineConfig:
    """Base configuration for engine test scenarios.

    Exactly one of ``script_name`` or ``command`` must be supplied; this is
    enforced in ``__post_init__``.

    Raises:
        ValueError: If neither or both of ``script_name`` and ``command``
            are provided.
    """

    # --- required fields ---
    name: str
    # Working directory for the process; launch scripts are resolved as
    # <directory>/launch/<script_name>.
    directory: str
    marks: List[Any]  # presumably pytest marks -- TODO confirm against callers
    request_payloads: List[BasePayload]
    model: str

    # --- launch method: set exactly one of these ---
    script_name: Optional[str] = None
    command: Optional[List[str]] = None

    # --- optional settings ---
    # Extra arguments appended after the script path (script launches only).
    script_args: Optional[List[str]] = None
    # Port used to build the /v1/models and /health health-check URLs.
    models_port: int = 8000
    # Forwarded unchanged to the process constructor.
    timeout: int = 600
    delayed_start: int = 0
    # Variables layered on top of os.environ when the process is created.
    env: Dict[str, str] = field(default_factory=dict)
    # Forwarded unchanged to the process constructor.
    stragglers: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate that either script_name or command is provided, but not both."""
        has_script = bool(self.script_name)
        has_command = bool(self.command)
        if not has_script and not has_command:
            raise ValueError("Either script_name or command must be provided")
        if has_script and has_command:
            raise ValueError("Cannot provide both script_name and command")

58

59
60
class EngineProcess(ManagedProcess):
    """Base class for LLM engine processes (vLLM, TRT-LLM, etc.).

    Adds response and log validation helpers plus factory constructors that
    build the process from an ``EngineConfig``.
    """

    def check_response(
        self,
        payload: BasePayload,
        response: requests.Response,
    ) -> None:
        """Check that the response is valid and contains expected content.

        Args:
            payload: The original payload. Its ``process_response`` method is
                used to extract the content, and a truthy ``expected_log``
                attribute triggers log validation afterwards.
            response: The response object.

        Raises:
            EngineResponseError: If the status code is not 200, or if
                ``payload.process_response`` raises (assertion or otherwise).
            EngineLogError: If ``payload.expected_log`` is set and log
                validation fails.
        """
        if response.status_code != 200:
            logger.error(
                "Response returned non-200 status code: %d", response.status_code
            )

            error_msg = f"Response returned non-200 status code: {response.status_code}"
            try:
                error_data = response.json()
                if "error" in error_data:
                    error_msg += f"\nError details: {error_data['error']}"
                logger.error(
                    "Response error details: %s", json.dumps(error_data, indent=2)
                )
            except Exception:
                # Deliberately broad: this is best-effort diagnostics only.
                # Whatever goes wrong while decoding the body, fall back to
                # logging a bounded slice of the raw text.
                logger.error("Response text: %s", response.text[:500])

            raise EngineResponseError(error_msg)

        try:
            content = payload.process_response(response)

            # Keep the log line short for long string content.
            if isinstance(content, str) and len(content) > 200:
                preview = content[:200] + "..."
            else:
                preview = content
            logger.info("Extracted content: \n%s", preview)
        except AssertionError as e:
            # Chain the cause so the original traceback stays attached.
            raise EngineResponseError(str(e)) from e
        except Exception as e:
            raise EngineResponseError(f"Failed to handle response: {e}") from e

        # Optionally validate expected log patterns after response handling
        if payload.expected_log:
            self.validate_expected_logs(payload.expected_log)

    def validate_expected_logs(self, patterns: Any) -> None:
        """Validate that all regex patterns are present in the current logs.

        Reads the log via ``read_logs`` and searches it for each pattern.

        Args:
            patterns: A single regex pattern string, or an iterable of regex
                patterns (one-shot iterators are supported).

        Raises:
            EngineLogError: If the log is unavailable/empty, or if any
                pattern is not found.
        """
        import re  # local import to keep module load minimal

        content = self.read_logs() or ""
        if not content:
            raise EngineLogError(
                f"Log file not available or empty at path: {self.log_path}"
            )

        # A bare string is one pattern, not an iterable of one-character
        # patterns. Materializing the iterable also guards against one-shot
        # iterators, which would otherwise be exhausted before the search
        # loop and produce a false "all found" result.
        if isinstance(patterns, str):
            patterns = [patterns]
        else:
            patterns = list(patterns)

        # Compile everything up front so an invalid pattern fails fast.
        compiled = [(pattern, re.compile(pattern)) for pattern in patterns]
        missing = [pattern for pattern, rx in compiled if not rx.search(content)]

        if missing:
            # Slicing already returns the whole string when it is shorter.
            sample = content[-1000:]
            raise EngineLogError(
                f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
            )
        logger.info("SUCCESS: All expected log patterns: %s found", patterns)

    @classmethod
    def from_config(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess from configuration (script or command).

        Args:
            config: Scenario configuration; selects script or direct-command
                launch.
            request: Object whose ``node.name`` is used as the log directory
                (presumably the pytest ``request`` fixture -- TODO confirm).
            extra_env: Optional variables applied last, overriding both the
                inherited environment and ``config.env``.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If the config provides neither a script nor a command.
            FileNotFoundError: If the configured launch script does not exist.
        """
        assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"

        if config.script_name:
            command = cls._build_script_command(config)
        elif config.command:
            # Copy so the config's own list is not shared with the process.
            command = config.command.copy()
        else:
            raise ValueError("Either script_name or command must be provided in config")

        # Precedence (lowest to highest): os.environ, config.env, extra_env.
        env = os.environ.copy()
        if config.env:
            env.update(config.env)
        if extra_env:
            env.update(extra_env)

        return cls(
            command=command,
            env=env,
            timeout=config.timeout,
            display_output=True,
            working_dir=config.directory,
            health_check_ports=[],
            health_check_urls=[
                (f"http://localhost:{config.models_port}/v1/models", check_models_api),
                (
                    f"http://localhost:{config.models_port}/health",
                    check_health_generate,
                ),
            ],
            delayed_start=config.delayed_start,
            terminate_existing=False,
            stragglers=config.stragglers,
            log_dir=request.node.name,
        )

    @classmethod
    def _build_script_command(cls, config: EngineConfig) -> List[str]:
        """Build command from script configuration.

        The script is resolved as ``<config.directory>/launch/<script_name>``
        and run with ``bash``; ``config.script_args`` are appended if set.

        Raises:
            FileNotFoundError: If the resolved script path does not exist.
        """
        assert (
            config.script_name
        ), "Must provide script_name to run fn _build_script_command"
        script_path = os.path.join(config.directory, "launch", config.script_name)

        if not os.path.exists(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")

        command: List[str] = ["bash", script_path]
        if config.script_args:
            command.extend(config.script_args)

        return command

    @classmethod
    def from_script(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess configured to run a launch script.

        Deprecated: Use from_config() instead. This simply delegates to it.
        """
        return cls.from_config(config, request, extra_env)

    @classmethod
    def from_command(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess configured to run a direct command.

        Deprecated: Use from_config() instead. This simply delegates to it.
        """
        return cls.from_config(config, request, extra_env)