# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import requests

from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): presumably the default HTTP frontend port; it equals the
# default of EngineConfig.models_port -- confirm intended usage.
FRONTEND_PORT = 8000

19
20
21
22
23
24
25

class EngineResponseError(Exception):
    """Raised when an engine response is invalid or lacks the expected content."""


26
27
class EngineLogError(Exception):
    """Custom exception for engine log validation errors."""
30
31


32
33
34
@dataclass
class EngineConfig:
    """Base configuration for engine test scenarios.

    Instances are consumed by ``EngineProcess.from_script``, which turns the
    fields below into a launch command, an environment, and health checks.
    """

    # --- Required fields -------------------------------------------------
    # Name of the scenario.
    name: str
    # Base directory: used as the process working directory, and the launch
    # script is looked up at ``<directory>/launch/<script_name>``.
    directory: str
    # File name of the launch script inside ``<directory>/launch``.
    script_name: str
    # NOTE(review): presumably pytest marks applied to the scenario -- confirm.
    marks: List[Any]
    # Payloads associated with this scenario.
    request_payloads: List[BasePayload]
    # NOTE(review): presumably the model identifier served by the engine -- confirm.
    model: str

    # --- Optional fields -------------------------------------------------
    # Extra arguments appended after the script path on the command line.
    script_args: Optional[List[str]] = None
    # Port used to build the ``/v1/models`` and ``/health`` health-check URLs.
    models_port: int = 8000
    # Forwarded to the process as ``timeout`` (presumably seconds -- confirm).
    timeout: int = 600
    # Forwarded to the process as ``delayed_start``.
    delayed_start: int = 0
    # Environment overrides applied on top of a copy of ``os.environ``.
    env: Dict[str, str] = field(default_factory=dict)
    # Forwarded to the process as ``stragglers``.
    stragglers: List[str] = field(default_factory=list)
49
50


51
52
class EngineProcess(ManagedProcess):
    """Base class for LLM engine processes (vLLM, TRT-LLM, etc.)."""

    def check_response(
        self,
        payload: BasePayload,
        response: requests.Response,
    ) -> None:
        """Check that a response is valid and contains the expected content.

        Args:
            payload: The original payload. Its ``process_response`` method
                extracts the content from ``response``, and its
                ``expected_log`` attribute optionally holds log patterns to
                verify afterwards.
            response: The HTTP response object to validate.

        Raises:
            EngineResponseError: If the status code is not 200, or if
                ``payload.process_response`` raises (assertion failures are
                reported with their own message).
            EngineLogError: If ``payload.expected_log`` is set and the logs
                are empty or a pattern is missing.
        """
        if response.status_code != 200:
            logger.error(
                "Response returned non-200 status code: %d", response.status_code
            )

            error_msg = f"Response returned non-200 status code: {response.status_code}"
            # Best-effort diagnostics only: any failure while decoding or
            # dumping the body falls back to logging the raw text, so the
            # broad except is deliberate.
            try:
                error_data = response.json()
                # Only a JSON object can carry an "error" key.
                if isinstance(error_data, dict) and "error" in error_data:
                    error_msg += f"\nError details: {error_data['error']}"
                logger.error(
                    "Response error details: %s", json.dumps(error_data, indent=2)
                )
            except Exception:
                logger.error("Response text: %s", response.text[:500])

            raise EngineResponseError(error_msg)

        try:
            content = payload.process_response(response)

            # Keep the log readable: long string content is truncated.
            if isinstance(content, str) and len(content) > 200:
                preview = content[:200] + "..."
            else:
                preview = content
            logger.info("Extracted content: \n%s", preview)
        except AssertionError as e:
            # Chain the cause so the original traceback is preserved.
            raise EngineResponseError(str(e)) from e
        except Exception as e:
            raise EngineResponseError(f"Failed to handle response: {e}") from e

        # Optionally validate expected log patterns after response handling
        if payload.expected_log:
            self.validate_expected_logs(payload.expected_log)

    def validate_expected_logs(self, patterns: Any) -> None:
        """Validate that all regex patterns are present in the current logs.

        Reads the full log via ManagedProcess.read_logs and searches for each
        provided regex pattern.

        Args:
            patterns: An iterable of regex patterns, or a single pattern
                string (treated as one pattern).

        Raises:
            EngineLogError: If the log is unavailable or empty, or if any
                pattern has no match in it.
        """
        import re  # local import to keep module load minimal

        content = self.read_logs() or ""
        if not content:
            raise EngineLogError(
                f"Log file not available or empty at path: {self.log_path}"
            )

        # A lone string is one pattern, not an iterable of one-character
        # patterns. Materializing also makes one-shot iterables (generators)
        # safe to both search with and report on.
        pattern_list = [patterns] if isinstance(patterns, str) else list(patterns)
        missing = [p for p in pattern_list if not re.search(p, content)]

        if missing:
            # The tail of the log is the most relevant part for debugging.
            sample = content[-1000:]
            raise EngineLogError(
                f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
            )
        logger.info("SUCCESS: All expected log patterns: %s found", pattern_list)

    @classmethod
    def from_script(
        cls,
        config: EngineConfig,
        request: Any,
        extra_env: Optional[Dict[str, str]] = None,
    ) -> "EngineProcess":
        """Factory to create an EngineProcess configured to run a launch script.

        The script is run as ``bash <config.directory>/launch/<config.script_name>``
        followed by ``config.script_args``.

        Args:
            config: Scenario configuration; must be an ``EngineConfig``.
            request: Object exposing ``node.name``, which is used as the log
                directory (presumably pytest's ``request`` fixture -- confirm).
            extra_env: Optional environment overrides applied last, on top of
                ``os.environ`` and ``config.env``.

        Returns:
            A new instance of ``cls`` with health checks on the
            ``/v1/models`` and ``/health`` endpoints of ``config.models_port``.

        Raises:
            AssertionError: If ``config`` is not an ``EngineConfig``.
            FileNotFoundError: If the launch script does not exist.
        """
        assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"

        directory = config.directory
        script_path = os.path.join(directory, "launch", config.script_name)

        if not os.path.exists(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")

        command: List[str] = ["bash", script_path]
        if config.script_args:
            command.extend(config.script_args)

        # Precedence (lowest to highest): os.environ, config.env, extra_env.
        env = os.environ.copy()
        if getattr(config, "env", None):
            env.update(config.env)
        if extra_env:
            env.update(extra_env)

        return cls(
            command=command,
            env=env,
            timeout=config.timeout,
            display_output=True,
            working_dir=directory,
            health_check_ports=[],
            health_check_urls=[
                (f"http://localhost:{config.models_port}/v1/models", check_models_api),
                (
                    f"http://localhost:{config.models_port}/health",
                    check_health_generate,
                ),
            ],
            delayed_start=config.delayed_start,
            terminate_existing=False,
            stragglers=config.stragglers,
            log_dir=request.node.name,
        )