# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
import math
18
import re
Alec's avatar
Alec committed
19
import time
20
from copy import deepcopy
21
from dataclasses import dataclass, field
Tzu-Ling Kan's avatar
Tzu-Ling Kan committed
22
from typing import Any, Callable, Dict, List, Optional, cast
23

24
25
import requests

26
from dynamo import prometheus_names  # type: ignore[attr-defined]
27
from tests.utils.constants import DefaultPort
28

29
30
31
32
33
34
35
36
logger = logging.getLogger(__name__)


@dataclass
class BasePayload:
    """Request body bundled with its expectations, target address and repeat count."""

    body: Dict[str, Any]
    # Either List[str] or List[List[str]] (alternatives).
    expected_response: List[Any]
    expected_log: List[str]
    repeat_count: int = 1
    timeout: int = 60

    # Connection info
    host: str = "localhost"
    port: int = DefaultPort.FRONTEND.value
    endpoint: str = ""
    method: str = "POST"

    # Extra ports for specialized payloads (e.g. LoRA system/control-plane APIs).
    # Empty by default so that existing payloads keep their prior semantics.
    system_ports: list[int] = field(default_factory=list)

    # When True, the HTTP request is made with stream=True (for SSE responses).
    http_stream: bool = False

    def url(self) -> str:
        """Build the ``http://host:port/endpoint`` URL for this payload."""
        path = self.endpoint.lstrip("/")
        return f"http://{self.host}:{self.port}/{path}"

    def with_model(self, model):
        """Return a deep copy of this payload with ``model`` filled into the body.

        A ``model`` entry that is already present in the body is left untouched.
        """
        clone = deepcopy(self)
        if "model" in clone.body:
            return clone
        clone.body = {**clone.body, "model": model}
        return clone

    def response_handler(self, response: Any) -> str:
        """Turn a response into text for logging/validation (subclass hook)."""
        raise NotImplementedError("Subclasses must implement response_handler()")

    def validate(self, response: Any, content: str) -> None:
        """Check that ``content`` contains at least one expected keyword.

        Matching is a case-insensitive substring test with OR semantics: one
        matching entry of ``expected_response`` is enough. Entries that are
        not strings never match. Nothing is checked when ``expected_response``
        is empty.

        Raises:
            AssertionError: if ``content`` is empty or no keyword matches.
        """
        if not self.expected_response:
            return

        # An empty body can never satisfy a keyword expectation.
        if not content:
            logger.error("VALIDATION FAILED - Response content is empty")
            raise AssertionError(
                f"Expected content not found in response. Expected any of: {self.expected_response}. Actual content is empty."
            )

        hits = [
            keyword
            for keyword in self.expected_response
            if isinstance(keyword, str) and keyword.lower() in content.lower()
        ]

        if hits:
            logger.info(
                f"SUCCESS: Found {len(hits)}/{len(self.expected_response)} expected keywords: {hits}"
            )
            return

        logger.error(f"VALIDATION FAILED - Actual content returned: {repr(content)}")
        logger.error(f"Expected to find at least one of: {self.expected_response}")
        logger.error(f"Matches found: 0/{len(self.expected_response)}")
        raise AssertionError(
            f"Expected content not found in response. Expected at least one of: {self.expected_response}. Actual content: {repr(content)}"
        )

    def process_response(self, response: Any) -> str:
        """Extract the content via ``response_handler``, validate it, and return it."""
        text = self.response_handler(response)
        self.validate(response, text)
        return text


@dataclass
class ChatPayload(BasePayload):
    """Payload targeting the chat completions endpoint."""

    endpoint: str = "/v1/chat/completions"

    @staticmethod
    def extract_content(response):
        """Return the first non-empty text field of a chat completions response.

        The first choice's message is inspected in this order:
        ``content``, ``reasoning_content``, ``refusal``, then the
        comma-joined ``arguments`` of any ``tool_calls``.

        Raises:
            AssertionError: if ``choices`` or the first choice's ``message`` is missing.
            ValueError: if every candidate field is empty.
        """
        response.raise_for_status()
        payload = response.json()

        assert (
            "choices" in payload
        ), f"Missing 'choices' in response. Response keys: {list(payload.keys())}"
        assert len(payload["choices"]) > 0, "Empty choices in response"
        assert (
            "message" in payload["choices"][0]
        ), f"Missing 'message' in first choice. Choice keys: {list(payload['choices'][0].keys())}"

        message = payload["choices"][0]["message"]

        # Parsers may place the output in any of these fields.
        content = message.get("content", "")
        reasoning_content = message.get("reasoning_content", "")
        refusal = message.get("refusal", "")
        tool_calls = message.get("tool_calls", [])

        tool_content = ""
        if tool_calls:
            argument_chunks = []
            for call in tool_calls:
                arguments = call.get("function", {}).get("arguments", "")
                if arguments:
                    argument_chunks.append(arguments)
            tool_content = ", ".join(argument_chunks)

        # `or` yields the first truthy candidate, in priority order.
        chosen = content or reasoning_content or refusal or tool_content
        if chosen:
            return chosen

        raise ValueError(
            "All possible content fields are empty in message. "
            f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
            f"refusal={repr(refusal)}, tool_calls={tool_calls}"
        )

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_content`."""
        return ChatPayload.extract_content(response)


167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
@dataclass
class ChatPayloadWithLogprobs(ChatPayload):
    """Chat payload that additionally checks the logprobs section of the response."""

    @staticmethod
    def _assert_valid_logprob_item(item) -> None:
        """Assert that one ``logprobs.content`` entry is well formed."""
        assert "token" in item, "Missing 'token' in logprobs content"
        assert "logprob" in item, "Missing 'logprob' in logprobs content"
        assert "top_logprobs" in item, "Missing 'top_logprobs' in logprobs content"

        # Sanity check: a logprob must be a finite, non-positive number.
        logprob_val = item["logprob"]
        assert not math.isnan(logprob_val), "logprob is NaN"
        assert not math.isinf(logprob_val), "logprob is infinite"
        assert logprob_val <= 0, f"logprob should be <= 0, got {logprob_val}"

        # The selected token must carry its bytes representation.
        assert "bytes" in item, "Missing 'bytes' in logprobs content item"
        token_str = item["token"]
        if token_str:
            assert (
                item["bytes"] is not None
            ), f"'bytes' should be populated for non-empty token {token_str!r}"
            assert isinstance(
                item["bytes"], list
            ), f"'bytes' should be a list, got {type(item['bytes'])}"

        # Each alternative needs token, logprob and bytes as well.
        for top_lp in item["top_logprobs"]:
            assert "token" in top_lp, "Missing 'token' in top_logprobs entry"
            assert "logprob" in top_lp, "Missing 'logprob' in top_logprobs entry"
            assert "bytes" in top_lp, "Missing 'bytes' in top_logprobs entry"
            if top_lp["token"]:
                assert (
                    top_lp["bytes"] is not None
                ), f"'bytes' should be populated for top_logprob token {top_lp['token']!r}"

    def validate(self, response: Any, content: str) -> None:
        """Run the inherited content validation, then validate logprobs.

        The first choice must contain a ``logprobs`` key. A null value, or an
        empty ``logprobs.content`` list, ends validation without further checks.

        Raises:
            AssertionError: if a required field is missing or a logprob is invalid.
        """
        super().validate(response, content)

        first_choice = response.json()["choices"][0]
        assert "logprobs" in first_choice, "Missing 'logprobs' in choice"

        logprobs_data = first_choice["logprobs"]
        if logprobs_data is None:
            return
        assert "content" in logprobs_data, "Missing 'content' in logprobs"

        content_logprobs = logprobs_data["content"]
        if not content_logprobs:
            return

        for item in content_logprobs:
            self._assert_valid_logprob_item(item)

        logger.info(
            f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
        )


235
236
237
238
239
240
241
242
243
@dataclass
class ToolCallingChatPayload(ChatPayload):
    """ChatPayload that validates tool calls in the response.

    ``expected_tool_name`` is a plain instance attribute set by the custom
    ``__init__`` below; it is not declared as a dataclass field.
    """

    def __init__(self, *args, expected_tool_name: Optional[str] = None, **kwargs):
        """Forward all arguments to the parent and record the expected tool name.

        Args:
            expected_tool_name: when set, ``validate`` requires that at least
                one tool call uses this function name.
        """
        super().__init__(*args, **kwargs)
        self.expected_tool_name = expected_tool_name

    def validate(self, response, content: str) -> None:
        """Validate that tool calls exist in the response.

        Skips the parent's expected_response substring check because tool call
        responses produce structured JSON arguments, not natural-language text.
        The expected_response keywords are instead matched against the
        concatenated tool call arguments so callers can still assert that the
        model "understood" the input (e.g. expected_response=["purple"]).

        Raises:
            AssertionError: if there are no choices or tool calls, a tool call
                is malformed, the expected tool was not called, or no
                expected keyword occurs in the tool call arguments.
        """
        import json

        response_data = response.json()
        choices = response_data.get("choices", [])
        assert choices, "Response missing choices"

        message = choices[0].get("message", {})
        tool_calls = message.get("tool_calls", [])

        assert tool_calls, "Expected model to generate tool calls but none found"
        logger.info(f"Tool calls detected: {len(tool_calls)} call(s)")

        # Validate tool call structure and collect the arguments as text.
        all_args: List[str] = []
        for i, tc in enumerate(tool_calls):
            assert "function" in tc, f"Tool call {i} missing 'function' field"
            function = tc.get("function", {})
            assert "name" in function, f"Tool call {i} missing function name"
            assert "arguments" in function, f"Tool call {i} missing function arguments"

            raw_args = function.get("arguments", "")
            # Arguments are normally a JSON string. Serialize anything else
            # (an already-parsed object, or null) so that the slicing and
            # joining below cannot raise TypeError.
            args_str = raw_args if isinstance(raw_args, str) else json.dumps(raw_args)
            all_args.append(args_str)
            logger.info(
                f"  [{i}] Function: {function.get('name')}, Args: {args_str[:100]}..."
            )

        # If expected tool name is provided, validate it
        if self.expected_tool_name:
            tool_names = [tc.get("function", {}).get("name") for tc in tool_calls]
            assert (
                self.expected_tool_name in tool_names
            ), f"Expected tool '{self.expected_tool_name}' not found. Available tools: {tool_names}"
            logger.info(f"Expected tool '{self.expected_tool_name}' was called")

        # Check expected_response keywords against tool call arguments (OR logic)
        if self.expected_response:
            combined_args = " ".join(all_args).lower()
            # Non-string entries are ignored, matching BasePayload.validate,
            # rather than crashing on a missing .lower().
            found = [
                kw
                for kw in self.expected_response
                if isinstance(kw, str) and kw.lower() in combined_args
            ]
            if not found:
                logger.error(
                    f"VALIDATION FAILED - Expected to find at least one of "
                    f"{self.expected_response} in tool call arguments"
                )
                logger.error(f"Tool call arguments: {combined_args}")
                raise AssertionError(
                    f"Expected content not found in tool call arguments. "
                    f"Expected at least one of: {self.expected_response}. "
                    f"Tool call arguments: {combined_args}"
                )
            logger.info(f"Found expected keywords in tool args: {found}")

301

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
@dataclass
class CachedTokensChatPayload(ChatPayload):
    """
    Chat payload that validates cached tokens are populated in repeated requests.

    Used for testing KV router cache-aware routing where repeated identical prompts
    should result in cached tokens being reported in the usage field.

    Each request after the first is checked for
    usage.prompt_tokens_details.cached_tokens >= min_cached_tokens. A miss on
    an individual request only logs a warning; ``final_validation`` fails if
    no repeated request ever met the threshold.
    """

    def __init__(
        self,
        body: dict,
        repeat_count: int = 3,
        expected_response: Optional[List[str]] = None,
        expected_log: Optional[List[str]] = None,
        timeout: int = 60,
        min_cached_tokens: int = 1,
    ):
        """Create the payload.

        Args:
            body: request body sent to the chat completions endpoint.
            repeat_count: how many times the request is meant to be sent.
            expected_response: keywords for content validation (``None`` -> ``[]``).
            expected_log: expected log entries (``None`` -> ``[]``).
            timeout: timeout value stored on the payload.
            min_cached_tokens: minimum cached token count that counts as a hit.
        """
        super().__init__(
            body=body,
            repeat_count=repeat_count,
            expected_response=expected_response or [],
            expected_log=expected_log or [],
            timeout=timeout,
        )
        self.min_cached_tokens = min_cached_tokens
        # Number of responses validated so far.
        self._request_count = 0
        # Becomes True once any repeated request reports enough cached tokens.
        self._cached_tokens_found = False

    def validate(self, response: Any, content: str) -> None:
        """Validate response and check for cached tokens on repeated requests."""
        # First run the standard content validation
        super().validate(response, content)

        self._request_count += 1
        result = response.json()

        # Expected structure: usage.prompt_tokens_details.cached_tokens.
        # Every level may be absent or null, so each lookup falls back to an
        # empty value instead of raising.
        usage = result.get("usage") or {}
        prompt_tokens_details = usage.get("prompt_tokens_details") or {}
        cached_tokens = prompt_tokens_details.get("cached_tokens", 0) or 0

        logger.info(
            f"Request {self._request_count}: prompt_tokens={usage.get('prompt_tokens')}, "
            f"cached_tokens={cached_tokens}, prompt_tokens_details={prompt_tokens_details}"
        )

        # The first request has nothing cached yet; only later ones are checked.
        if self._request_count <= 1:
            return

        if cached_tokens >= self.min_cached_tokens:
            self._cached_tokens_found = True
            logger.info(
                f"✓ Request {self._request_count}: Cached tokens validation PASSED - "
                f"found {cached_tokens} cached tokens (min required: {self.min_cached_tokens})"
            )
        else:
            logger.warning(
                f"Request {self._request_count}: cached_tokens={cached_tokens} "
                f"(expected >= {self.min_cached_tokens})"
            )

    def final_validation(self) -> None:
        """Called after all requests are processed to ensure we saw cached tokens.

        Raises AssertionError if cached tokens were not found on any repeated request.
        With ``repeat_count <= 1`` there are no repeated requests, so the check
        is skipped and no success is reported.
        """
        if self.repeat_count <= 1:
            logger.info(
                "Final validation skipped: repeat_count <= 1, no repeated requests to check"
            )
            return

        if not self._cached_tokens_found:
            raise AssertionError(
                f"Expected cached_tokens >= {self.min_cached_tokens} in "
                f"prompt_tokens_details for at least one repeated request, "
                f"but none found after {self._request_count} requests. "
                f"Verify that prefix caching is enabled and working correctly."
            )
        logger.info(
            "✓ Final validation PASSED: cached_tokens found in repeated requests"
        )


385
386
387
388
389
390
391
392
393
394
395
396
397
398
@dataclass
class LoraTestChatPayload(ChatPayload):
    """
    Chat payload that loads a LoRA adapter before sending inference requests.

    This payload first loads the specified LoRA adapter via the system API,
    then sends chat completion requests using the LoRA model. Loading is
    triggered lazily by the first call to ``url()``.
    """

    def __init__(
        self,
        body: dict,
        lora_name: str,
        s3_uri: str,
        system_port: int = DefaultPort.SYSTEM1.value,
        repeat_count: int = 1,
        expected_response: Optional[list] = None,
        expected_log: Optional[list] = None,
        timeout: int = 60,
    ):
        """Create the payload.

        Args:
            body: request body sent to the chat completions endpoint.
            lora_name: adapter name; also the model id awaited in /v1/models.
            s3_uri: adapter location passed to ``load_lora_adapter``.
            system_port: port passed to ``load_lora_adapter``; stored as the
                only entry of ``system_ports``.
            repeat_count: how many times the request is meant to be sent.
            expected_response: keywords for content validation (``None`` -> ``[]``).
            expected_log: expected log entries (``None`` -> ``[]``).
            timeout: seconds used both for the load call and for the wait
                on /v1/models.
        """
        super().__init__(
            body=body,
            repeat_count=repeat_count,
            expected_response=expected_response or [],
            expected_log=expected_log or [],
            timeout=timeout,
        )
        self.system_ports = [system_port]
        self.lora_name = lora_name
        self.s3_uri = s3_uri
        self._lora_loaded = False

    def _ensure_lora_loaded(self) -> None:
        """Load the LoRA adapter once and wait until /v1/models lists it.

        Raises:
            RuntimeError: if the model does not appear within ``self.timeout`` seconds.
        """
        if self._lora_loaded:
            return

        # Imported here to avoid circular dependencies.
        from tests.serve.lora_utils import load_lora_adapter

        load_lora_adapter(
            system_port=self.system_ports[0],
            lora_name=self.lora_name,
            s3_uri=self.s3_uri,
            timeout=self.timeout,
        )

        # Wait for the LoRA model to appear in /v1/models
        models_url = f"http://{self.host}:{self.port}/v1/models"
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + self.timeout

        logger.info(
            f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..."
        )

        while time.monotonic() < deadline:
            try:
                response = requests.get(models_url, timeout=5)
                if response.status_code == 200:
                    data = response.json()
                    models = data.get("data", [])
                    model_ids = [m.get("id", "") for m in models]

                    if self.lora_name in model_ids:
                        logger.info(f"LoRA model '{self.lora_name}' is now available")
                        self._lora_loaded = True
                        return

                    logger.debug(
                        f"Available models: {model_ids}, waiting for '{self.lora_name}'..."
                    )
            except (requests.RequestException, ValueError) as e:
                # ValueError covers a body that is not valid JSON on requests
                # versions whose JSON error is not a RequestException; treat
                # it as transient and poll again.
                logger.debug(f"Error checking /v1/models: {e}")

            time.sleep(1)

        raise RuntimeError(
            f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {self.timeout}s"
        )

    def url(self) -> str:
        """Load LoRA before first request, then return URL"""
        self._ensure_lora_loaded()
        return super().url()


472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
@dataclass
class CompletionPayload(BasePayload):
    """Payload targeting the completions endpoint."""

    endpoint: str = "/v1/completions"

    @staticmethod
    def extract_text(response):
        """Return the ``text`` of the first choice of a completions response.

        Raises:
            AssertionError: if ``choices`` is missing or empty, or the first
                choice has no ``text`` key.
        """
        response.raise_for_status()
        payload = response.json()
        assert "choices" in payload, "Missing 'choices' in response"

        choices = payload["choices"]
        assert len(choices) > 0, "Empty choices in response"

        first_choice = choices[0]
        assert "text" in first_choice, "Missing 'text' in first choice"
        return first_choice["text"]

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_text`."""
        return CompletionPayload.extract_text(response)


494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
@dataclass
class CompletionPayloadWithLogprobs(CompletionPayload):
    """Completion payload that validates logprobs in response."""

    def validate(self, response: Any, content: str) -> None:
        """Run the inherited content validation, then validate logprobs.

        The first choice must contain a ``logprobs`` key. A null value, or an
        empty ``token_logprobs`` list, ends validation without further checks.

        Raises:
            AssertionError: if a required field is missing, the list lengths
                differ, or a logprob is NaN, infinite or positive.
        """
        super().validate(response, content)

        result = response.json()
        choice = result["choices"][0]

        # Validate logprobs field exists
        assert "logprobs" in choice, "Missing 'logprobs' in choice"

        logprobs_data = choice["logprobs"]
        if logprobs_data is None:
            return

        assert "token_logprobs" in logprobs_data, "Missing 'token_logprobs' in logprobs"
        assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs"

        token_logprobs = logprobs_data["token_logprobs"]
        tokens = logprobs_data["tokens"]

        if not token_logprobs:
            return

        assert len(token_logprobs) == len(
            tokens
        ), "Mismatch between token_logprobs and tokens length"

        # Sanity check: each logprob should be valid (not nan/inf/positive)
        for i, logprob_val in enumerate(token_logprobs):
            if logprob_val is None:  # First token can be None
                continue
            assert not math.isnan(logprob_val), f"logprob at index {i} is NaN"
            assert not math.isinf(logprob_val), f"logprob at index {i} is infinite"
            assert (
                logprob_val <= 0
            ), f"logprob at index {i} should be <= 0, got {logprob_val}"

        # Validate top_logprobs entries have token, logprob, and bytes when present.
        # `or []` also covers an explicit null, which .get()'s default would
        # pass through and make enumerate() raise TypeError.
        top_logprobs_list = logprobs_data.get("top_logprobs") or []
        for i, token_top_lps in enumerate(top_logprobs_list):
            if not token_top_lps:
                continue
            for top_lp in token_top_lps:
                assert "token" in top_lp, f"Missing 'token' in top_logprobs[{i}] entry"
                assert (
                    "logprob" in top_lp
                ), f"Missing 'logprob' in top_logprobs[{i}] entry"
                assert "bytes" in top_lp, f"Missing 'bytes' in top_logprobs[{i}] entry"
                if top_lp["token"]:
                    assert (
                        top_lp["bytes"] is not None
                    ), f"'bytes' should be populated for top_logprob token {top_lp['token']!r}"

        logger.info(
            f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
        )


561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
@dataclass
class ResponsesPayload(BasePayload):
    """Payload for the Responses API endpoint (/v1/responses).

    For full compliance testing, use the OpenResponses bun CLI:
      bun run test:compliance --base-url http://localhost:<port>/v1 --api-key test --model <model>
    See https://www.openresponses.org/compliance
    """

    endpoint: str = "/v1/responses"

    @staticmethod
    def extract_content(response):
        """Validate the envelope of a Responses API reply and return its text.

        Checks the top-level ``object``, ``id`` prefix and ``status``, then
        requires the first output item to be an assistant message whose first
        content part has type ``output_text``.

        Returns:
            The ``text`` of that first content part ("" if the key is absent).

        Raises:
            AssertionError: if any of the structural checks fails.
        """
        response.raise_for_status()
        payload = response.json()

        # Envelope checks.
        assert (
            payload.get("object") == "response"
        ), f"Expected object='response', got {payload.get('object')}"
        assert payload.get("id", "").startswith(
            "resp_"
        ), f"Expected id to start with 'resp_', got {payload.get('id')}"
        assert (
            payload.get("status") == "completed"
        ), f"Expected status='completed', got {payload.get('status')}"

        # First output item must be the assistant message.
        output_items = payload.get("output", [])
        assert len(output_items) > 0, "Response output is empty"
        first_item = output_items[0]
        assert (
            first_item.get("type") == "message"
        ), f"Expected output[0].type='message', got {first_item.get('type')}"
        assert (
            first_item.get("role") == "assistant"
        ), f"Expected role='assistant', got {first_item.get('role')}"

        # First content part must be output text.
        parts = first_item.get("content", [])
        assert len(parts) > 0, "Message content is empty"
        first_part = parts[0]
        assert (
            first_part.get("type") == "output_text"
        ), f"Expected content[0].type='output_text', got {first_part.get('type')}"

        return first_part.get("text", "")

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_content`."""
        return ResponsesPayload.extract_content(response)


@dataclass
class ResponsesStreamPayload(BasePayload):
    """Streaming payload for the Responses API endpoint (/v1/responses).

    Validates SSE event structure and lifecycle ordering.
    """

    endpoint: str = "/v1/responses"
    http_stream: bool = True

    @staticmethod
    def extract_content(response):
        """Parse SSE stream and validate event structure.

        Each ``data:`` line is paired with the most recent ``event:`` line; a
        ``[DONE]`` data line is recorded as a ``done`` event.

        Returns:
            The concatenation of all ``response.output_text.delta`` deltas.

        Raises:
            AssertionError: if lifecycle ordering is wrong, a required event
                type is missing, or the deltas do not add up to the text of
                the single ``response.output_text.done`` event.
        """
        import json

        response.raise_for_status()

        events: list[tuple[str, Any]] = []
        event_type = ""
        # Read raw bytes and decode as UTF-8 here. With decode_unicode=True,
        # requests falls back to ISO-8859-1 for text/* responses without a
        # charset: non-ASCII text is garbled and byte 0x85 becomes a line
        # break that splits a data line mid-JSON. With no encoding at all it
        # yields bytes, which breaks the str prefix checks below.
        for raw_line in response.iter_lines():
            if isinstance(raw_line, bytes):
                line = raw_line.decode("utf-8", errors="replace")
            else:
                line = raw_line
            if not line:
                continue
            if line.startswith("event: "):
                event_type = line[len("event: ") :]
            elif line.startswith("data: "):
                data_str = line[len("data: ") :]
                if data_str == "[DONE]":
                    events.append(("done", None))
                else:
                    events.append((event_type, json.loads(data_str)))

        event_types = [e[0] for e in events]

        # Validate lifecycle event ordering
        assert len(event_types) >= 2, f"Too few events: {event_types}"
        assert (
            event_types[0] == "response.created"
        ), f"First event should be response.created, got {event_types[0]}"
        assert (
            event_types[1] == "response.in_progress"
        ), f"Second event should be response.in_progress, got {event_types[1]}"

        # The "done" sentinel is not a real event; ignore it for the final check.
        non_done = [e for e in event_types if e != "done"]
        assert (
            non_done[-1] == "response.completed"
        ), f"Last real event should be response.completed, got {non_done[-1]}"

        # Validate text content events
        assert "response.output_item.added" in event_types, "Missing output_item.added"
        assert (
            "response.content_part.added" in event_types
        ), "Missing content_part.added"
        assert "response.output_text.delta" in event_types, "Missing output_text.delta"
        assert "response.output_text.done" in event_types, "Missing output_text.done"
        assert "response.content_part.done" in event_types, "Missing content_part.done"
        assert "response.output_item.done" in event_types, "Missing output_item.done"

        # Verify text deltas concatenate to the final text
        deltas = [e[1]["delta"] for e in events if e[0] == "response.output_text.delta"]
        done_events = [e for e in events if e[0] == "response.output_text.done"]
        assert (
            len(done_events) == 1
        ), f"Expected 1 output_text.done, got {len(done_events)}"
        full_text = "".join(deltas)
        assert (
            done_events[0][1]["text"] == full_text
        ), "Concatenated deltas don't match output_text.done text"

        return full_text

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_content`."""
        return ResponsesStreamPayload.extract_content(response)


685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
@dataclass
class AnthropicMessagesPayload(BasePayload):
    """Payload for the Anthropic Messages API endpoint (/v1/messages)."""

    endpoint: str = "/v1/messages"

    @staticmethod
    def extract_content(response):
        """Validate an Anthropic Messages API reply and return its text.

        Checks ``type``, the ``id`` prefix, ``role`` and ``stop_reason``,
        requires the first content block to be of type ``text``, and requires
        both token counters in ``usage``.

        Returns:
            The ``text`` of the first content block ("" if the key is absent).

        Raises:
            AssertionError: if any of the structural checks fails.
        """
        response.raise_for_status()
        payload = response.json()

        # Envelope checks.
        assert (
            payload.get("type") == "message"
        ), f"Expected type='message', got {payload.get('type')}"
        assert payload.get("id", "").startswith(
            "msg_"
        ), f"Expected id to start with 'msg_', got {payload.get('id')}"
        assert (
            payload.get("role") == "assistant"
        ), f"Expected role='assistant', got {payload.get('role')}"

        allowed_stop_reasons = ("end_turn", "max_tokens", "stop_sequence", "tool_use")
        assert (
            payload.get("stop_reason") in allowed_stop_reasons
        ), f"Unexpected stop_reason: {payload.get('stop_reason')}"

        # First content block must be text.
        blocks = payload.get("content", [])
        assert len(blocks) > 0, "Response content is empty"
        first_block = blocks[0]
        assert (
            first_block.get("type") == "text"
        ), f"Expected content[0].type='text', got {first_block.get('type')}"

        # Usage must report both token counters.
        usage = payload.get("usage", {})
        assert "input_tokens" in usage, "Missing input_tokens in usage"
        assert "output_tokens" in usage, "Missing output_tokens in usage"

        return first_block.get("text", "")

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_content`."""
        return AnthropicMessagesPayload.extract_content(response)


@dataclass
class AnthropicMessagesStreamPayload(BasePayload):
    """Streaming payload for the Anthropic Messages API endpoint (/v1/messages).

    Validates SSE event structure and lifecycle ordering per the Anthropic streaming spec.
    """

    endpoint: str = "/v1/messages"
    http_stream: bool = True

    @staticmethod
    def extract_content(response):
        """Parse the SSE stream, validate the Anthropic event lifecycle, return the text.

        Each ``data`` line is JSON-decoded and paired with the most recent
        ``event`` line. The collected events must start with ``message_start``,
        end with ``message_stop``, contain the content-block events and exactly
        one ``message_delta`` carrying a known ``stop_reason``.

        Returns:
            The concatenation of every ``text_delta`` found in
            ``content_block_delta`` events.

        Raises:
            AssertionError: when the event sequence or structure is not as expected.
            json.JSONDecodeError: when a ``data`` field is not valid JSON.
        """
        import json

        response.raise_for_status()

        # Read the raw lines and decode them here. With decode_unicode=True the
        # HTTP client decodes using the charset it guessed from the headers,
        # which is not guaranteed to be UTF-8 (and it may hand back bytes when
        # nothing was guessed); event streams are UTF-8.
        events = []
        event_type = ""
        for raw_line in response.iter_lines():
            line = (
                raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
            )
            if not line:
                continue
            field_name, colon, value = line.partition(":")
            if not colon:
                continue
            # A single space after the colon is optional in SSE; drop it if present.
            if value.startswith(" "):
                value = value[1:]
            if field_name == "event":
                event_type = value
            elif field_name == "data":
                events.append((event_type, json.loads(value)))

        event_types = [e[0] for e in events]

        # Validate lifecycle event ordering
        assert len(event_types) >= 3, f"Too few events: {event_types}"
        assert (
            event_types[0] == "message_start"
        ), f"First event should be message_start, got {event_types[0]}"
        assert (
            event_types[-1] == "message_stop"
        ), f"Last event should be message_stop, got {event_types[-1]}"

        # Validate message_start structure
        msg_start = events[0][1]
        assert msg_start.get("type") == "message_start", "message_start missing type"
        message = msg_start.get("message", {})
        assert message.get("id", "").startswith(
            "msg_"
        ), "message id should start with msg_"
        assert message.get("role") == "assistant", "message role should be assistant"

        # Validate required event types
        assert "content_block_start" in event_types, "Missing content_block_start"
        assert "content_block_delta" in event_types, "Missing content_block_delta"
        assert "content_block_stop" in event_types, "Missing content_block_stop"
        assert "message_delta" in event_types, "Missing message_delta"

        # Validate message_delta has stop_reason
        delta_events = [e for e in events if e[0] == "message_delta"]
        assert (
            len(delta_events) == 1
        ), f"Expected 1 message_delta, got {len(delta_events)}"
        delta_body = delta_events[0][1].get("delta", {})
        assert delta_body.get("stop_reason") in (
            "end_turn",
            "max_tokens",
            "stop_sequence",
            "tool_use",
        ), f"Unexpected stop_reason in message_delta: {delta_body.get('stop_reason')}"

        # Collect text deltas, in stream order
        text_parts = []
        for e_type, e_data in events:
            if e_type != "content_block_delta":
                continue
            delta = e_data.get("delta", {})
            if delta.get("type") == "text_delta":
                text_parts.append(delta.get("text", ""))

        return "".join(text_parts)

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_content`."""
        return AnthropicMessagesStreamPayload.extract_content(response)


810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
@dataclass
class EmbeddingPayload(BasePayload):
    """Request payload targeting the /v1/embeddings endpoint."""

    endpoint: str = "/v1/embeddings"

    @staticmethod
    def extract_embeddings(response):
        """Validate an embeddings response and describe it in one line.

        The top-level object must be a non-empty ``list`` and every entry in
        ``data`` must be an ``embedding`` object holding a non-empty list.

        Returns:
            A summary string giving the number of vectors and the length of
            the first one.

        Raises:
            AssertionError: when the response does not have the expected shape.
        """
        response.raise_for_status()
        result = response.json()

        # Envelope checks.
        assert "object" in result, "Missing 'object' in response"
        assert (
            result["object"] == "list"
        ), f"Expected object='list', got {result['object']}"
        assert "data" in result, "Missing 'data' in response"
        assert len(result["data"]) > 0, "Empty data in response"

        def checked_vector(item):
            """Return the vector of one ``data`` entry after validating the entry."""
            assert "object" in item, "Missing 'object' in embedding item"
            assert (
                item["object"] == "embedding"
            ), f"Expected object='embedding', got {item['object']}"
            assert "embedding" in item, "Missing 'embedding' vector in item"
            vector = item["embedding"]
            assert isinstance(vector, list), "Embedding should be a list of floats"
            assert len(vector) > 0, "Embedding vector should not be empty"
            return vector

        embeddings = [checked_vector(item) for item in result["data"]]

        # The caller validates against this summary rather than the raw vectors.
        return f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0])}"

    def response_handler(self, response: Any) -> str:
        """Delegate to :meth:`extract_embeddings`."""
        return EmbeddingPayload.extract_embeddings(response)


851
852
853
854
855
856
857
858
859
860
861
862
@dataclass
class MetricCheck:
    """Definition of a metric validation check.

    Consumed by ``MetricsPayload._validate_metric_checks``, which builds the
    regex via ``pattern``, feeds the match(es) to ``validator`` and then uses
    ``success_msg`` or ``error_msg`` to report the outcome.
    """

    # Metric name, or a descriptive label (e.g. "vllm:*") for multiline checks.
    # For non-multiline checks this text must occur verbatim in the metrics output.
    name: str
    # Receives `name` and returns the regular-expression source to search for.
    pattern: Callable[[str], str]
    # Receives a single regex match, or the full list of matches when
    # `multiline` is True, and returns whether the check passes.
    validator: Callable[[Any], bool]
    # Builds the AssertionError message from (name, value) when validation fails.
    error_msg: Callable[[str, Any], str]
    # Builds the info-level log message from (name, value) when validation passes.
    success_msg: Callable[[str, Any], str]
    # When True the pattern is applied with re.MULTILINE and the validator is
    # called once with every match instead of once per match.
    multiline: bool = False


863
864
@dataclass
class MetricsPayload(BasePayload):
    """Base class for Prometheus metrics validation payloads.

    Validates common dynamo_component_* metrics shared across all backends.
    Backend-specific subclasses handle engine-specific metrics.
    """

    endpoint: str = "/metrics"
    method: str = "GET"
    port: int = DefaultPort.SYSTEM1.value
    # Lower bound applied by the requests_total check.
    min_num_requests: int = 1

    def with_model(self, model):
        """Return ``self`` unchanged."""
        # Metrics does not use model in request body
        return self

    def response_handler(self, response: Any) -> str:
        """Return the raw metrics text after checking the HTTP status."""
        response.raise_for_status()
        return response.text

    def _filter_bucket_metrics(self, content: str) -> str:
        """Filter out histogram bucket metrics to avoid count inflation"""
        kept_lines = [line for line in content.split("\n") if "_bucket{" not in line]
        return "\n".join(kept_lines)

    def _get_common_metric_checks(self) -> list[MetricCheck]:
        """Get common dynamo_component_* metric checks shared across all backends"""
        prefix = prometheus_names.name_prefix.COMPONENT

        # 80% of typical ~13 metrics (excluding _bucket and removed kvstats metrics).
        # Kept in one place so the validator and its messages cannot drift apart.
        min_unique = 7

        # Define metrics to check
        # Pattern matches: metric_name{labels} value OR metric_name value (labels optional)
        # Examples:
        #   - dynamo_component_requests_total{model="Qwen/Qwen3-0.6B"} 6
        #   - dynamo_component_uptime_seconds 150.390999059
        # Note: Supports scientific notation (e.g., 8.34e-05)
        def metric_pattern(name):
            return rf"{name}(?:\{{[^}}]*\}})?\s+([\d.eE+-]+)"

        return [
            MetricCheck(
                # Check: Minimum count of unique dynamo_component_* metrics
                name=f"{prefix}_*",
                pattern=lambda name: rf"^{prefix}_\w+",
                validator=lambda value: len(set(value)) >= min_unique,
                error_msg=lambda name, value: (
                    f"Expected at least {min_unique} unique {prefix}_* metrics, but found only {len(set(value))}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {len(set(value))} unique {prefix}_* metrics (minimum required: {min_unique})"
                ),
                multiline=True,
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.work_handler.REQUESTS_TOTAL}",
                pattern=metric_pattern,
                validator=lambda value: int(float(value)) >= self.min_num_requests,
                error_msg=lambda name, value: (
                    f"{name} has count {value} which is less than required {self.min_num_requests}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {name} with count: {value}"
                ),
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.distributed_runtime.UPTIME_SECONDS}",
                pattern=metric_pattern,
                validator=lambda value: float(value) > 0,
                error_msg=lambda name, value: f"{name} should be > 0, but got {value}",
                success_msg=lambda name, value: f"SUCCESS: Found {name} = {value}s",
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.kvstats.TOTAL_BLOCKS}",
                pattern=metric_pattern,
                validator=lambda value: float(value) >= 0,
                error_msg=lambda name, value: f"{name} should be >= 0, but got {value}",
                success_msg=lambda name, value: f"SUCCESS: Found {name} = {value}",
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.kvstats.GPU_CACHE_USAGE_PERCENT}",
                pattern=metric_pattern,
                validator=lambda value: 0.0 <= float(value) <= 1.0,
                error_msg=lambda name, value: (
                    f"{name} should be between 0.0 and 1.0, but got {value}"
                ),
                success_msg=lambda name, value: f"SUCCESS: Found {name} = {value}",
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.model_info.LOAD_TIME_SECONDS}",
                pattern=metric_pattern,
                validator=lambda value: float(value) > 0,
                error_msg=lambda name, value: f"{name} should be > 0, but got {value}",
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {name} = {float(value):.2f}s"
                ),
            ),
        ]

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """Get backend-specific metric checks. Override in subclasses."""
        return []

    def _validate_metric_checks(
        self, metrics_to_check: list[MetricCheck], content: str
    ) -> None:
        """Run all metric checks and raise AssertionError if any fail.

        Multiline checks hand every match to the validator at once. Standard
        checks pass as soon as one matched value satisfies the validator; if
        none does, the error message reports the last matched value.
        """
        for metric in metrics_to_check:
            # Special handling for multiline patterns (like counting unique metrics)
            if metric.multiline:
                pattern = metric.pattern(metric.name)
                matches = re.findall(pattern, content, re.MULTILINE)
                if not matches:
                    raise AssertionError(
                        f"Could not find any matches for pattern '{metric.name}'"
                    )
                # For multiline, pass the entire list to validator
                if not metric.validator(matches):
                    raise AssertionError(metric.error_msg(metric.name, matches))
                logger.info(metric.success_msg(metric.name, matches))
                continue

            # Standard single-value metric check
            if metric.name not in content:
                raise AssertionError(
                    f"Metric '{metric.name}' not found in metrics output"
                )

            pattern = metric.pattern(metric.name)
            matches = re.findall(pattern, content)
            if not matches:
                raise AssertionError(
                    f"Could not parse value for metric '{metric.name}'"
                )

            # For metrics with multiple values (like requests_total with different labels),
            # check if any match passes validation
            for match in matches:
                try:
                    passed = metric.validator(match)
                except ValueError:
                    # The value pattern can capture text that is not a number
                    # (for instance a lone sign); count it as a failing sample
                    # instead of aborting with a conversion error.
                    passed = False
                if passed:
                    logger.info(metric.success_msg(metric.name, match))
                    break
            else:
                raise AssertionError(
                    metric.error_msg(metric.name, matches[-1] or "N/A")
                )

    def validate(self, response: Any, content: str) -> None:
        """Validate Prometheus metrics output"""
        content = self._filter_bucket_metrics(content)

        # Collect all checks: common + backend-specific
        metrics_to_check = self._get_common_metric_checks()
        metrics_to_check.extend(self._get_backend_specific_checks())

        # Run all validations
        self._validate_metric_checks(metrics_to_check, content)


@dataclass
class VLLMMetricsPayload(MetricsPayload):
    """Metrics validation for vLLM backend with auto-label checks"""

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """vLLM-specific metric checks.

        Returns one check on the number of distinct ``vllm:*`` metric names,
        followed by one check per required label.
        """
        # 80% of typical ~70 vllm metrics (excluding _bucket) as of 2026-02-05 (but will grow).
        # Single source for the threshold used by the validator and both messages.
        min_unique = 56

        checks = [
            MetricCheck(
                # Check: Minimum count of unique vllm:* metrics
                name="vllm:*",
                pattern=lambda name: r"^vllm:\w+",
                validator=lambda value: len(set(value)) >= min_unique,
                error_msg=lambda name, value: (
                    f"Expected at least {min_unique} unique vllm:* metrics, but found only {len(set(value))}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {len(set(value))} unique vllm:* metrics (minimum required: {min_unique})"
                ),
                multiline=True,
            )
        ]

        # Check required labels: auto-injected (from prometheus_names.labels) + injected by backend
        required_labels = [
            prometheus_names.labels.NAMESPACE,
            prometheus_names.labels.COMPONENT,
            prometheus_names.labels.ENDPOINT,
            prometheus_names.labels.MODEL,  # OpenAI standard (injected by all backends)
            prometheus_names.labels.MODEL_NAME,  # Alternative label (injected for compatibility)
        ]
        for label_name in required_labels:
            # `lbl=label_name` binds the current label at definition time so each
            # lambda keeps its own label rather than the loop's final value.
            checks.append(
                MetricCheck(
                    name=f"vllm:* with {label_name}",
                    pattern=cast(
                        Callable[[str], str],
                        lambda name, lbl=label_name: rf'vllm:\w+\{{[^}}]*{lbl}="[^"]+"',
                    ),
                    validator=lambda value: len(value) > 0,
                    error_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"vLLM metrics missing label: {lbl}"
                        ),
                    ),
                    success_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"SUCCESS: vLLM metrics include {lbl} label (found {len(value)} metrics)"
                        ),
                    ),
                    multiline=True,
                )
            )

        return checks


@dataclass
class LMCacheMetricsPayload(MetricsPayload):
    """Metrics validation for lmcache"""

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """lmcache-specific metric checks.

        Returns a single check on the number of distinct ``lmcache:*`` metric names.
        """
        # 80% of typical ~33 lmcache metrics (excluding _bucket) as of 2026-02-05 (but will grow).
        # Single source for the threshold used by the validator and both messages.
        min_unique = 26

        return [
            MetricCheck(
                # Check: Minimum count of unique lmcache:* metrics
                name="lmcache:*",
                pattern=lambda name: r"^lmcache:\w+",
                validator=lambda value: len(set(value)) >= min_unique,
                error_msg=lambda name, value: (
                    f"Expected at least {min_unique} unique lmcache:* metrics, but found only {len(set(value))}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {len(set(value))} lmcache:* metrics (minimum required: {min_unique})"
                ),
                multiline=True,
            )
        ]


@dataclass
class SGLangMetricsPayload(MetricsPayload):
    """Metrics validation for SGLang backend with auto-label checks"""

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """SGLang-specific metric checks.

        Returns one check on the number of distinct ``sglang:*`` metric names
        (named ``"sglang:*"``), followed by one check per required label.
        """
        # 80% of typical ~25 sglang metrics (excluding _bucket) as of 2025-10-22 (but will grow).
        # Single source for the threshold used by the validator and both messages.
        min_unique = 20

        checks = [
            MetricCheck(
                # Check: Minimum count of unique sglang:* metrics
                name="sglang:*",
                pattern=lambda name: r"^sglang:\w+",
                validator=lambda value: len(set(value)) >= min_unique,
                error_msg=lambda name, value: (
                    f"Expected at least {min_unique} unique sglang:* metrics, but found only {len(set(value))}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {len(set(value))} unique sglang:* metrics (minimum required: {min_unique})"
                ),
                multiline=True,
            )
        ]

        # Check required labels: auto-injected (from prometheus_names.labels) + injected by backend
        required_labels = [
            prometheus_names.labels.NAMESPACE,
            prometheus_names.labels.COMPONENT,
            prometheus_names.labels.ENDPOINT,
            prometheus_names.labels.MODEL,  # OpenAI standard (injected by all backends)
            prometheus_names.labels.MODEL_NAME,  # Alternative label (injected for compatibility)
        ]
        for label_name in required_labels:
            # `lbl=label_name` binds the current label at definition time so each
            # lambda keeps its own label rather than the loop's final value.
            checks.append(
                MetricCheck(
                    name=f"sglang:* with {label_name}",
                    pattern=cast(
                        Callable[[str], str],
                        lambda name, lbl=label_name: (
                            rf'sglang:\w+\{{[^}}]*{lbl}="[^"]+"'
                        ),
                    ),
                    validator=lambda value: len(value) > 0,
                    error_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"sglang metrics missing label: {lbl}"
                        ),
                    ),
                    success_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"SUCCESS: sglang metrics include {lbl} label (found {len(value)} metrics)"
                        ),
                    ),
                    multiline=True,
                )
            )

        return checks


1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
@dataclass
class SGLangDisaggMetricsPayload(SGLangMetricsPayload):
    """SGLang metrics validation tuned for disaggregated (prefill/decode) workers.

    According to the original author's observation, a disagg worker runs only
    half of the scheduler pipeline and therefore exposes fewer sglang:*
    metrics: ~14 unique names versus ~25 on an aggregated worker.
    """

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """Return the parent's checks with the ``sglang:*`` count floor lowered to 10."""

        def has_enough_unique(value):
            return len(set(value)) >= 10

        def describe_failure(name, value):
            return f"Expected at least 10 unique sglang:* metrics, but found only {len(set(value))}"

        def describe_success(name, value):
            return f"SUCCESS: Found {len(set(value))} unique sglang:* metrics (minimum required: 10)"

        inherited_checks = super()._get_backend_specific_checks()
        for check in inherited_checks:
            # Only the unique-count check is relaxed; label checks stay as inherited.
            if check.name != "sglang:*":
                continue
            check.validator = has_enough_unique
            check.error_msg = describe_failure
            check.success_msg = describe_success
        return inherited_checks


1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
@dataclass
class TRTLLMMetricsPayload(MetricsPayload):
    """Metrics validation for TensorRT-LLM backend"""

    def _get_backend_specific_checks(self) -> list[MetricCheck]:
        """TRT-LLM-specific metric checks.

        Returns one check on the number of distinct ``trtllm_*`` metric names,
        followed by one check per required label.
        """
        # 80% of typical ~5 trtllm metrics (excluding _bucket) as of 2025-10-22 (but will grow).
        # Single source for the threshold used by the validator and both messages.
        min_unique = 4

        checks = [
            MetricCheck(
                # Check: Minimum count of unique trtllm_* metrics
                name="trtllm_*",
                pattern=lambda name: r"^trtllm_\w+",
                validator=lambda value: len(set(value)) >= min_unique,
                error_msg=lambda name, value: (
                    f"Expected at least {min_unique} unique trtllm_* metrics, but found only {len(set(value))}"
                ),
                success_msg=lambda name, value: (
                    f"SUCCESS: Found {len(set(value))} unique trtllm_* metrics (minimum required: {min_unique})"
                ),
                multiline=True,
            )
        ]

        # Check required labels: auto-injected (from prometheus_names.labels) + injected by backend
        required_labels = [
            prometheus_names.labels.NAMESPACE,
            prometheus_names.labels.COMPONENT,
            prometheus_names.labels.ENDPOINT,
            prometheus_names.labels.MODEL,  # OpenAI standard (injected by all backends)
            prometheus_names.labels.MODEL_NAME,  # Alternative label (injected for compatibility)
        ]
        for label_name in required_labels:
            # `lbl=label_name` binds the current label at definition time so each
            # lambda keeps its own label rather than the loop's final value.
            checks.append(
                MetricCheck(
                    name=f"trtllm_* with {label_name}",
                    pattern=cast(
                        Callable[[str], str],
                        lambda name, lbl=label_name: (
                            rf'trtllm_\w+\{{[^}}]*{lbl}="[^"]+"'
                        ),
                    ),
                    validator=lambda value: len(value) > 0,
                    error_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"TRT-LLM metrics missing label: {lbl}"
                        ),
                    ),
                    success_msg=cast(
                        Callable[[str, Any], str],
                        lambda name, value, lbl=label_name: (
                            f"SUCCESS: TRT-LLM metrics include {lbl} label (found {len(value)} metrics)"
                        ),
                    ),
                    multiline=True,
                )
            )

        return checks

1262
1263
1264
1265
1266
1267
1268

def check_models_api(response):
    """Check if models API is working and returns models.

    Args:
        response: HTTP response object exposing ``status_code`` and ``json()``.

    Returns:
        True when the status is 200 and the JSON body has a non-empty
        ``data`` entry; False otherwise, including when the body cannot be
        parsed or has an unexpected shape.
    """
    try:
        if response.status_code != 200:
            return False
        models = response.json().get("data")
        if not models or len(models) == 0:
            return False
        # temporary to avoid /completions race condition where we get 404 error.
        # Only needed once models are reported; a negative result returns at once.
        time.sleep(1)
        return True
    except Exception:
        # Best-effort probe: any failure simply means "not ready yet".
        return False


# Additional health check helpers
def check_health_generate(response):
    """Validate /health reports a 'generate' endpoint.

    Returns True if either of the following is found:
      - "endpoints" contains a string mentioning 'generate'
      - "instances" contains an object with endpoint == 'generate'

    Any other outcome (non-200 status, unparsable body, unexpected shape)
    returns False.
    """
    try:
        if response.status_code != 200:
            return False
        data = response.json()

        # Missing or null lists are treated as empty.
        endpoints = data.get("endpoints", []) or []
        instances = data.get("instances", []) or []

        # Endpoints are checked first; instances are only scanned when no
        # endpoint string mentions 'generate'.
        found = any(
            isinstance(ep, str) and "generate" in ep for ep in endpoints
        ) or any(
            isinstance(inst, dict) and inst.get("endpoint") == "generate"
            for inst in instances
        )
        if not found:
            return False

        # temporary to avoid /completions race condition where we get 404 error
        time.sleep(1)
        return True
    except Exception:
        # Best-effort probe: any failure simply means "not ready yet".
        return False


# backwards compatibility
def completions_response_handler(response):
    """Module-level alias that forwards ``response`` to ``CompletionPayload.extract_text``."""
    return CompletionPayload.extract_text(response)


def chat_completions_response_handler(response):
    """Module-level alias that forwards ``response`` to ``ChatPayload.extract_content``."""
    return ChatPayload.extract_content(response)
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367


@dataclass
class ImageGenerationPayload(BasePayload):
    """Request payload for diffusion image generation (/v1/images/generations)."""

    endpoint: str = "/v1/images/generations"
    timeout: int = 300

    def response_handler(self, response: Any) -> str:
        """Validate the image response and return a token describing the first image.

        Returns:
            The first entry's ``url`` when that key is present, otherwise the
            fixed marker ``"b64_image_returned"`` once ``b64_json`` is confirmed
            to be non-empty.

        Raises:
            AssertionError: when ``data`` is missing or empty, or the chosen
                field of the first entry is empty.
        """
        response.raise_for_status()
        result = response.json()

        assert (
            "data" in result
        ), f"Missing 'data' in response. Keys: {list(result.keys())}"
        images = result["data"]
        assert len(images) > 0, "Empty data in image response"

        first_image = images[0]
        if "url" not in first_image:
            # Inline payload: only confirm it is there, do not return the blob.
            assert first_image.get("b64_json"), "Image response b64_json is empty"
            return "b64_image_returned"

        image_url = first_image["url"]
        assert image_url, "Image response url is empty"
        return image_url


@dataclass
class VideoGenerationPayload(BasePayload):
    """Request payload for diffusion video generation (/v1/videos)."""

    endpoint: str = "/v1/videos"
    timeout: int = 600

    def response_handler(self, response: Any) -> str:
        """Validate the video response and return a token describing the first video.

        Returns:
            The first entry's ``url`` when that key is present, otherwise the
            fixed marker ``"b64_video_returned"`` once ``b64_json`` is confirmed
            to be non-empty.

        Raises:
            AssertionError: when ``status`` is not ``"completed"``, ``data`` is
                missing or empty, or the chosen field of the first entry is empty.
        """
        response.raise_for_status()
        result = response.json()

        # The job must have finished before its data is inspected.
        job_status = result.get("status")
        assert job_status == "completed", (
            f"Video generation not completed. Status: {result.get('status')}, "
            f"Error: {result.get('error', 'none')}"
        )

        assert (
            "data" in result
        ), f"Missing 'data' in response. Keys: {list(result.keys())}"
        videos = result["data"]
        assert len(videos) > 0, "Empty data in video response"

        first_video = videos[0]
        if "url" not in first_video:
            # Inline payload: only confirm it is there, do not return the blob.
            assert first_video.get("b64_json"), "Video response b64_json is empty"
            return "b64_video_returned"

        video_url = first_video["url"]
        assert video_url, "Video response url is empty"
        return video_url