payloads.py 27.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
import math
18
import re
Alec's avatar
Alec committed
19
import time
20
from copy import deepcopy
21
from dataclasses import dataclass, field
22
from typing import Any, Callable, Dict, List, Optional
23

24
25
import requests

26
from dynamo import prometheus_names  # type: ignore[attr-defined]
27
from tests.utils.constants import DefaultPort
28

29
30
31
32
33
34
35
36
logger = logging.getLogger(__name__)


@dataclass
class BasePayload:
    """Request body bundled with connection details, expectations and a repeat count.

    Subclasses provide ``response_handler`` to turn a response object into
    text; ``validate`` then checks that text against ``expected_response``.
    """

    body: Dict[str, Any]
    expected_response: List[Any]  # Can be List[str] or List[List[str]] for alternatives
    expected_log: List[str]
    repeat_count: int = 1
    timeout: int = 60

    # Connection info
    host: str = "localhost"
    port: int = DefaultPort.FRONTEND.value
    endpoint: str = ""
    method: str = "POST"
    # Optional additional ports used by specialized payloads (e.g. LoRA system/control-plane APIs).
    # This is intentionally empty by default to preserve prior semantics.
    system_ports: list[int] = field(default_factory=list)

    def url(self) -> str:
        """Build the ``http://host:port/endpoint`` URL for this payload."""
        # Leading slashes are stripped so the f-string supplies exactly one.
        return f"http://{self.host}:{self.port}/{self.endpoint.lstrip('/')}"

    def with_model(self, model):
        """Return a deep copy whose body has ``model`` set, unless one is already present."""
        clone = deepcopy(self)
        if "model" in clone.body:
            return clone
        clone.body = {**clone.body, "model": model}
        return clone

    def response_handler(self, response: Any) -> str:
        """Turn ``response`` into the text used for logging/validation (subclass hook)."""
        raise NotImplementedError("Subclasses must implement response_handler()")

    def validate(self, response: Any, content: str) -> None:
        """Check that ``content`` contains at least one expected keyword.

        Matching is a case-insensitive substring test with OR semantics: one
        hit among ``expected_response`` is enough. Entries that are not
        strings are skipped. An empty ``expected_response`` disables the check.

        Raises:
            AssertionError: if ``content`` is empty or no keyword matches.
        """
        wanted = self.expected_response
        if not wanted:
            return

        if not content:
            logger.error("VALIDATION FAILED - Response content is empty")
            raise AssertionError(
                f"Expected content not found in response. Expected any of: {wanted}. Actual content is empty."
            )

        # Collect every keyword that occurs in the content (case-insensitive).
        hits = [
            keyword
            for keyword in wanted
            if isinstance(keyword, str) and keyword.lower() in content.lower()
        ]

        if not hits:
            logger.error(
                f"VALIDATION FAILED - Actual content returned: {repr(content)}"
            )
            logger.error(f"Expected to find at least one of: {wanted}")
            logger.error(f"Matches found: 0/{len(wanted)}")
            raise AssertionError(
                f"Expected content not found in response. Expected at least one of: {wanted}. Actual content: {repr(content)}"
            )

        logger.info(
            f"SUCCESS: Found {len(hits)}/{len(wanted)} expected keywords: {hits}"
        )

    def process_response(self, response: Any) -> str:
        """Extract the content via ``response_handler``, validate it, and return it."""
        extracted = self.response_handler(response)
        self.validate(response, extracted)
        return extracted


@dataclass
class ChatPayload(BasePayload):
    """Payload targeting the chat completions endpoint."""

    endpoint: str = "/v1/chat/completions"

    @staticmethod
    def extract_content(response):
        """Return the first non-empty text field of the first chat choice.

        Fields are tried in this order: ``content``, ``reasoning_content``,
        ``refusal``, then the comma-joined ``arguments`` of any tool calls.

        Raises:
            AssertionError: if ``choices`` or the first choice's ``message`` is missing.
            ValueError: if every candidate field is empty.
        """
        response.raise_for_status()
        result = response.json()

        assert (
            "choices" in result
        ), f"Missing 'choices' in response. Response keys: {list(result.keys())}"
        assert len(result["choices"]) > 0, "Empty choices in response"
        first_choice = result["choices"][0]
        assert (
            "message" in first_choice
        ), f"Missing 'message' in first choice. Choice keys: {list(first_choice.keys())}"

        message = first_choice["message"]

        # Parsers may place output in any of these fields:
        #   content            - standard message content
        #   reasoning_content  - models with reasoning parsers
        #   refusal            - the model declined to answer
        #   tool_calls         - function/tool calling responses
        content = message.get("content", "")
        reasoning_content = message.get("reasoning_content", "")
        refusal = message.get("refusal", "")
        tool_calls = message.get("tool_calls", [])

        tool_content = ""
        if tool_calls:
            all_arguments = (
                call.get("function", {}).get("arguments", "") for call in tool_calls
            )
            tool_content = ", ".join(args for args in all_arguments if args)

        chosen = content or reasoning_content or refusal or tool_content
        if not chosen:
            raise ValueError(
                "All possible content fields are empty in message. "
                f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
                f"refusal={repr(refusal)}, tool_calls={tool_calls}"
            )
        return chosen

    def response_handler(self, response: Any) -> str:
        """Delegate to ``ChatPayload.extract_content``."""
        return ChatPayload.extract_content(response)


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@dataclass
class ChatPayloadWithLogprobs(ChatPayload):
    """Chat payload that additionally checks the logprobs of the first choice."""

    def validate(self, response: Any, content: str) -> None:
        """Run the base validation, then verify the logprobs structure.

        A ``logprobs`` key must exist on the first choice. When its value is
        not ``None`` it must hold a ``content`` list; each entry of a
        non-empty list needs ``token``, ``logprob`` and ``top_logprobs``, and
        ``logprob`` must be a finite number that is <= 0.
        """
        super().validate(response, content)

        first_choice = response.json()["choices"][0]
        assert "logprobs" in first_choice, "Missing 'logprobs' in choice"

        logprobs_data = first_choice["logprobs"]
        if logprobs_data is None:
            return

        assert "content" in logprobs_data, "Missing 'content' in logprobs"
        entries = logprobs_data["content"]
        if not entries:
            return

        for entry in entries:
            # Structural check first, in the same order for every entry.
            for required_key in ("token", "logprob", "top_logprobs"):
                assert (
                    required_key in entry
                ), f"Missing '{required_key}' in logprobs content"

            # Sanity check: logprob should be valid (not nan/inf/positive)
            value = entry["logprob"]
            assert not math.isnan(value), "logprob is NaN"
            assert not math.isinf(value), "logprob is infinite"
            assert value <= 0, f"logprob should be <= 0, got {value}"

        logger.info(
            f"✓ Logprobs validation passed: found {len(entries)} tokens with logprobs"
        )


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
@dataclass
class ToolCallingChatPayload(ChatPayload):
    """ChatPayload that additionally requires tool calls in the response.

    Accepts every ``ChatPayload`` constructor argument plus the keyword-only
    ``expected_tool_name``; when set, at least one tool call must target a
    function with that name.
    """

    def __init__(self, *args, expected_tool_name: Optional[str] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_tool_name = expected_tool_name

    def validate(self, response, content: str) -> None:
        """Run the standard validation, then verify the tool calls.

        Raises:
            AssertionError: if there are no choices or tool calls, if a tool
                call lacks a well-formed ``function`` (``name`` and
                ``arguments``), or if ``expected_tool_name`` was not called.
        """
        # First run the standard validation
        super().validate(response, content)

        # Then validate tool calls specifically
        response_data = response.json()
        choices = response_data.get("choices", [])
        assert choices, "Response missing choices"

        message = choices[0].get("message", {})
        tool_calls = message.get("tool_calls", [])

        assert tool_calls, "Expected model to generate tool calls but none found"
        logger.info(f"Tool calls detected: {len(tool_calls)} call(s)")

        # Validate tool call structure
        for i, tc in enumerate(tool_calls):
            assert "function" in tc, f"Tool call {i} missing 'function' field"
            function = tc["function"]
            # A null/non-object 'function' should fail as an assertion, not a TypeError.
            assert isinstance(
                function, dict
            ), f"Tool call {i} 'function' field is not an object"
            assert "name" in function, f"Tool call {i} missing function name"
            assert "arguments" in function, f"Tool call {i} missing function arguments"
            # str() keeps this purely diagnostic log line from raising when the
            # arguments value is None or an already-parsed object.
            args_preview = str(function["arguments"])[:100]
            logger.info(
                f"  [{i}] Function: {function.get('name')}, Args: {args_preview}..."
            )

        # If expected tool name is provided, validate it
        if self.expected_tool_name:
            tool_names = [tc.get("function", {}).get("name") for tc in tool_calls]
            assert (
                self.expected_tool_name in tool_names
            ), f"Expected tool '{self.expected_tool_name}' not found. Available tools: {tool_names}"
            logger.info(f"Expected tool '{self.expected_tool_name}' was called")


249
250
251
252
253
254
255
256
257
258
259
260
261
262
@dataclass
class LoraTestChatPayload(ChatPayload):
    """
    Chat payload that loads a LoRA adapter before sending inference requests.

    This payload first loads the specified LoRA adapter via the system API,
    then sends chat completion requests using the LoRA model. Loading is
    triggered lazily by the first call to ``url()``.
    """

    def __init__(
        self,
        body: dict,
        lora_name: str,
        s3_uri: str,
        system_port: int = DefaultPort.SYSTEM1.value,
        repeat_count: int = 1,
        expected_response: Optional[list] = None,
        expected_log: Optional[list] = None,
        timeout: int = 60,
    ):
        super().__init__(
            body=body,
            repeat_count=repeat_count,
            expected_response=expected_response or [],
            expected_log=expected_log or [],
            timeout=timeout,
        )
        self.system_ports = [system_port]
        self.lora_name = lora_name
        self.s3_uri = s3_uri
        self._lora_loaded = False

    def _ensure_lora_loaded(self) -> None:
        """Load the LoRA adapter once and wait until ``/v1/models`` lists it.

        Raises:
            RuntimeError: if the adapter name does not appear in ``/v1/models``
                within ``self.timeout`` seconds of the load call returning.
        """
        if self._lora_loaded:
            return

        # Note: This import is done here to avoid circular dependencies
        from tests.serve.lora_utils import load_lora_adapter

        load_lora_adapter(
            system_port=self.system_ports[0],
            lora_name=self.lora_name,
            s3_uri=self.s3_uri,
            timeout=self.timeout,
        )

        # Wait for the LoRA model to appear in /v1/models
        models_url = f"http://{self.host}:{self.port}/v1/models"
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + self.timeout

        logger.info(
            f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..."
        )

        while time.monotonic() < deadline:
            try:
                response = requests.get(models_url, timeout=5)
                if response.status_code == 200:
                    data = response.json()
                    models = data.get("data", [])
                    model_ids = [m.get("id", "") for m in models]

                    if self.lora_name in model_ids:
                        logger.info(f"LoRA model '{self.lora_name}' is now available")
                        self._lora_loaded = True
                        return

                    logger.debug(
                        f"Available models: {model_ids}, waiting for '{self.lora_name}'..."
                    )
            except (requests.RequestException, ValueError) as e:
                # ValueError covers a non-JSON body on requests versions whose
                # JSON decode error is not a RequestException; just retry.
                logger.debug(f"Error checking /v1/models: {e}")

            time.sleep(1)

        raise RuntimeError(
            f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {self.timeout}s"
        )

    def url(self) -> str:
        """Load LoRA before first request, then return URL"""
        self._ensure_lora_loaded()
        return super().url()


336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
@dataclass
class CompletionPayload(BasePayload):
    """Payload targeting the completions endpoint."""

    endpoint: str = "/v1/completions"

    @staticmethod
    def extract_text(response):
        """Return the ``text`` of the first choice in a completions response.

        Raises:
            AssertionError: if ``choices`` is missing or empty, or the first
                choice has no ``text``.
        """
        response.raise_for_status()
        parsed = response.json()
        assert "choices" in parsed, "Missing 'choices' in response"

        choices = parsed["choices"]
        assert len(choices) > 0, "Empty choices in response"

        first_choice = choices[0]
        assert "text" in first_choice, "Missing 'text' in first choice"
        return first_choice["text"]

    def response_handler(self, response: Any) -> str:
        """Delegate to ``CompletionPayload.extract_text``."""
        return CompletionPayload.extract_text(response)


358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
@dataclass
class CompletionPayloadWithLogprobs(CompletionPayload):
    """Completion payload that additionally checks the logprobs of the first choice."""

    def validate(self, response: Any, content: str) -> None:
        """Run the base validation, then verify the logprobs structure.

        A ``logprobs`` key must exist on the first choice. When its value is
        not ``None`` it must contain ``token_logprobs`` and ``tokens``; a
        non-empty ``token_logprobs`` must match ``tokens`` in length, and each
        non-``None`` entry must be a finite number that is <= 0.
        """
        super().validate(response, content)

        first_choice = response.json()["choices"][0]
        assert "logprobs" in first_choice, "Missing 'logprobs' in choice"

        logprobs_data = first_choice["logprobs"]
        if logprobs_data is None:
            return

        for required_key in ("token_logprobs", "tokens"):
            assert required_key in logprobs_data, f"Missing '{required_key}' in logprobs"

        token_logprobs = logprobs_data["token_logprobs"]
        tokens = logprobs_data["tokens"]
        if not token_logprobs:
            return

        assert len(token_logprobs) == len(
            tokens
        ), "Mismatch between token_logprobs and tokens length"

        # Sanity check: each logprob should be valid (not nan/inf/positive)
        for idx, value in enumerate(token_logprobs):
            if value is None:  # First token can be None
                continue
            assert not math.isnan(value), f"logprob at index {idx} is NaN"
            assert not math.isinf(value), f"logprob at index {idx} is infinite"
            assert value <= 0, f"logprob at index {idx} should be <= 0, got {value}"

        logger.info(
            f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
        )


405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
@dataclass
class EmbeddingPayload(BasePayload):
    """Payload targeting the embeddings endpoint."""

    endpoint: str = "/v1/embeddings"

    @staticmethod
    def extract_embeddings(response):
        """Validate an embeddings response and return a one-line summary.

        The summary reports how many embeddings were returned and the
        dimension of the first one.

        Raises:
            AssertionError: if the envelope or any item is malformed.
        """
        response.raise_for_status()
        payload = response.json()

        assert "object" in payload, "Missing 'object' in response"
        assert (
            payload["object"] == "list"
        ), f"Expected object='list', got {payload['object']}"
        assert "data" in payload, "Missing 'data' in response"
        items = payload["data"]
        assert len(items) > 0, "Empty data in response"

        def checked_vector(item):
            # Validate one entry of ``data`` and hand back its vector.
            assert "object" in item, "Missing 'object' in embedding item"
            assert (
                item["object"] == "embedding"
            ), f"Expected object='embedding', got {item['object']}"
            assert "embedding" in item, "Missing 'embedding' vector in item"
            vector = item["embedding"]
            assert isinstance(vector, list), "Embedding should be a list of floats"
            assert len(vector) > 0, "Embedding vector should not be empty"
            return vector

        vectors = [checked_vector(item) for item in items]

        # Return a summary string for validation
        return f"Generated {len(vectors)} embeddings with dimension {len(vectors[0])}"

    def response_handler(self, response: Any) -> str:
        """Delegate to ``EmbeddingPayload.extract_embeddings``."""
        return EmbeddingPayload.extract_embeddings(response)


446
447
448
449
450
451
452
453
454
455
456
457
@dataclass
class MetricCheck:
    """Definition of a metric validation check

    Consumed by ``MetricsPayload.validate``, which builds a regex from
    ``pattern``, runs ``re.findall`` over the metrics text and hands the
    result to ``validator``.
    """

    # Metric name (or a display label such as "vllm:*" for count checks).
    name: str
    # Called with ``name``; returns the regex passed to ``re.findall``.
    pattern: Callable[[str], str]
    # Receives the whole match list when ``multiline`` is True, otherwise one
    # matched value string at a time; returns True when the check passes.
    validator: Callable[[Any], bool]
    # Called with (name, value) to build the AssertionError message on failure.
    error_msg: Callable[[str, Any], str]
    # Called with (name, value) to build the log line on success.
    success_msg: Callable[[str, Any], str]
    # True: search with re.MULTILINE and validate all matches together.
    multiline: bool = False


458
459
460
461
@dataclass
class MetricsPayload(BasePayload):
    """Payload that fetches a metrics endpoint and validates the exported metrics.

    ``validate`` does not use ``expected_response``. It checks that enough
    distinct component metrics are exported (plus backend metrics when
    ``backend`` is set) and that a few individual metrics hold sane values.
    """

    endpoint: str = "/metrics"
    method: str = "GET"
    port: int = DefaultPort.SYSTEM1.value
    min_num_requests: int = 1
    # Backend identifier for metrics validation (e.g., 'vllm', 'sglang', 'trtllm')
    backend: Optional[str] = None

    def with_model(self, model):
        """Return ``self`` unchanged: metrics does not use model in request body."""
        return self

    def response_handler(self, response: Any) -> str:
        """Return the raw response text after raising on HTTP error statuses."""
        response.raise_for_status()
        return response.text

    @staticmethod
    def _sample_pattern(name: str) -> str:
        """Build a regex capturing the sample value of metric ``name``.

        Pattern matches: metric_name{labels} value OR metric_name value (labels optional)
        Examples:
          - dynamo_component_requests_total{model="Qwen/Qwen3-0.6B"} 6
          - dynamo_component_uptime_seconds 150.390999059
          - some_metric 1.5e+09
        """
        # Accept an optional sign and exponent so values such as "1e+06" are
        # captured whole instead of being truncated to "1".
        number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
        return name + r"(?:\{[^}]*\})?\s+(" + number + ")"

    @staticmethod
    def _unique_count_check(label: str, regex: str, minimum: int) -> "MetricCheck":
        """Build a check requiring at least ``minimum`` distinct matches of ``regex``."""
        return MetricCheck(
            name=label,
            pattern=lambda _name: regex,
            validator=lambda found: len(set(found)) >= minimum,
            error_msg=lambda _name, found: f"Expected at least {minimum} unique {label} metrics, but found only {len(set(found))}",
            success_msg=lambda _name, found: f"SUCCESS: Found {len(set(found))} unique {label} metrics (minimum required: {minimum})",
            multiline=True,
        )

    def _build_checks(self) -> List["MetricCheck"]:
        """Assemble the component checks plus any backend-specific count check."""
        # Build full metric names with prefix
        prefix = prometheus_names.name_prefix.COMPONENT

        checks = [
            # Check: Minimum count of unique dynamo_component_* metrics
            # 80% of typical ~17 metrics (excluding _bucket) as of 2025-12-02
            self._unique_count_check(f"{prefix}_*", rf"^{prefix}_\w+", 11),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.work_handler.REQUESTS_TOTAL}",
                pattern=self._sample_pattern,
                validator=lambda value: int(float(value)) >= self.min_num_requests,
                error_msg=lambda name, value: f"{name} has count {value} which is less than required {self.min_num_requests}",
                success_msg=lambda name, value: f"SUCCESS: Found {name} with count: {value}",
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.distributed_runtime.UPTIME_SECONDS}",
                pattern=self._sample_pattern,
                validator=lambda value: float(value) > 0,
                error_msg=lambda name, value: f"{name} should be > 0, but got {value}",
                success_msg=lambda name, value: f"SUCCESS: Found {name} = {value}s",
            ),
            MetricCheck(
                name=f"{prefix}_{prometheus_names.kvstats.TOTAL_BLOCKS}",
                pattern=self._sample_pattern,
                # Allow 0 for SGLang (hardcoded issue in components/src/dynamo/sglang/publisher.py:70)
                validator=lambda value: int(float(value)) >= 0,
                error_msg=lambda name, value: f"{name} should be >= 0, but got {value}",
                success_msg=lambda name, value: f"SUCCESS: Found {name} = {value}",
            ),
        ]

        # Backend-specific minimum counts of unique metrics (excluding _bucket):
        #   vllm:   80% of typical ~65 metrics as of 2025-10-22 (but will grow)
        #   sglang: 80% of typical ~25 metrics as of 2025-10-22 (but will grow)
        #   trtllm: 80% of typical ~5 metrics as of 2025-10-22 (but will grow)
        backend_minimums = {
            "vllm": ("vllm:*", r"^vllm:\w+", 52),
            "sglang": ("sglang:*", r"^sglang:\w+", 20),
            "trtllm": ("trtllm_*", r"^trtllm_\w+", 4),
        }

        if self.backend in backend_minimums:
            label, regex, minimum = backend_minimums[self.backend]
            checks.append(self._unique_count_check(label, regex, minimum))
        elif self.backend == "lmcache":
            # lmcache keeps its own message wording; at least 1 metric is required.
            checks.append(
                MetricCheck(
                    name="lmcache:*",
                    pattern=lambda name: r"^lmcache:\w+",
                    validator=lambda value: len(set(value)) >= 1,
                    error_msg=lambda name, value: f"Expected at least 1 lmcache:* metric, but found only {len(set(value))}",
                    success_msg=lambda name, value: f"SUCCESS: Found {len(set(value))} lmcache:* metrics",
                    multiline=True,
                )
            )

        return checks

    def validate(self, response: Any, content: str) -> None:
        """Validate the metrics text in ``content``.

        Raises:
            AssertionError: if a metric is missing, its value cannot be
                parsed, or a check's validator rejects what was found.
        """
        # Filter out _bucket metrics from content (histogram buckets inflate counts)
        content = "\n".join(
            line for line in content.split("\n") if "_bucket{" not in line
        )

        for check in self._build_checks():
            if check.multiline:
                # Count-style check: the validator sees every match at once.
                found = re.findall(check.pattern(check.name), content, re.MULTILINE)
                if not found:
                    raise AssertionError(
                        f"Could not find any matches for pattern '{check.name}'"
                    )
                if not check.validator(found):
                    raise AssertionError(check.error_msg(check.name, found))
                logger.info(check.success_msg(check.name, found))
                continue

            # Standard single-value metric check
            if check.name not in content:
                raise AssertionError(
                    f"Metric '{check.name}' not found in metrics output"
                )

            samples = re.findall(check.pattern(check.name), content)
            if not samples:
                raise AssertionError(
                    f"Could not parse value for metric '{check.name}'"
                )

            # For metrics with multiple values (like requests_total with different labels),
            # one passing sample is enough.
            for sample in samples:
                if check.validator(sample):
                    logger.info(check.success_msg(check.name, sample))
                    break
            else:
                # Report the last sample seen, as the failure context.
                raise AssertionError(check.error_msg(check.name, samples[-1]))
632
633
634
635
636
637
638
639


def check_models_api(response):
    """Check if models API is working and returns models.

    Returns:
        bool: True only when the status is 200 and the JSON body has a
        non-empty ``data`` entry; False otherwise, including when reading the
        response raises.
    """
    try:
        if response.status_code != 200:
            return False

        models = response.json().get("data")
        if not models or len(models) == 0:
            # Always a real bool, never the falsy payload itself (None / []).
            return False

        # temporary to avoid /completions race condition where we get 404 error
        # (only needed on success, matching check_health_generate)
        time.sleep(1)
        return True
    except Exception:
        # Best-effort readiness probe: any failure means "not ready yet".
        return False


# Additional health check helpers
def check_health_generate(response):
    """Tell whether a /health response advertises a 'generate' endpoint.

    Returns True when the status is 200 and either:
      - "endpoints" holds a string containing 'generate', or
      - "instances" holds a dict whose "endpoint" equals 'generate'.
    Any other outcome, including an error while reading the response,
    yields False.
    """
    try:
        if response.status_code != 200:
            return False
        payload = response.json()

        # A string entry in "endpoints" that mentions 'generate' ...
        advertised = any(
            isinstance(entry, str) and "generate" in entry
            for entry in (payload.get("endpoints", []) or [])
        )
        # ... or, failing that, an instance dict with endpoint == 'generate'.
        advertised = advertised or any(
            isinstance(instance, dict) and instance.get("endpoint") == "generate"
            for instance in (payload.get("instances", []) or [])
        )
        if not advertised:
            return False

        # temporary to avoid /completions race condition where we get 404 error
        time.sleep(1)
        return True
    except Exception:
        return False


# Backwards compatibility
def completions_response_handler(response):
    """Module-level alias that forwards to ``CompletionPayload.extract_text``."""
    return CompletionPayload.extract_text(response)


def chat_completions_response_handler(response):
    """Module-level alias that forwards to ``ChatPayload.extract_content``."""
    return ChatPayload.extract_content(response)