Unverified Commit 93208162 authored by Alec's avatar Alec Committed by GitHub
Browse files

refactor: standardize e2e tests across 3 frameworks (#2827)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
parent f0cea269
......@@ -3,114 +3,66 @@
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Any, List
from dataclasses import dataclass, field
import pytest
import requests
from tests.utils.managed_process import ManagedProcess
from tests.serve.common import run_serve_deployment
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import chat_payload_default, completion_payload_default
logger = logging.getLogger(__name__)
def validate_log_patterns(log_file, patterns):
"""Validate log patterns after test completion."""
if not os.path.exists(log_file):
raise AssertionError(f"Log file not found: {log_file}")
with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
content = f.read()
compiled = [re.compile(p) for p in patterns]
missing = []
for pattern, rx in zip(patterns, compiled):
if not rx.search(content):
missing.append(pattern)
if missing:
# Include sample of log content for debugging
sample = content[-1000:] if len(content) > 1000 else content
raise AssertionError(
f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
)
return True
@dataclass
class SGLangConfig:
class SGLangConfig(EngineConfig):
"""Configuration for SGLang test scenarios"""
script_name: str
marks: List[Any]
name: str
class SGLangProcess(ManagedProcess):
"""Simple process manager for sglang shell scripts"""
def __init__(self, script_name, request):
self.port = 8000
sglang_dir = os.environ.get(
"SGLANG_DIR", "/workspace/components/backends/sglang"
)
script_path = os.path.join(sglang_dir, "launch", script_name)
# Verify script exists
if not os.path.exists(script_path):
raise FileNotFoundError(f"SGLang script not found: {script_path}")
stragglers: list[str] = field(default_factory=lambda: ["SGLANG:EngineCore"])
# Make script executable and run it
command = ["bash", script_path]
# Focus kv-router logs for kv_events run
env = os.environ.copy()
if script_name == "agg_router.sh":
env.setdefault(
"DYN_LOG",
"dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
)
super().__init__(
command=command,
env=env,
timeout=900,
display_output=True,
working_dir=sglang_dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=60, # Give SGLang more time to fully start
terminate_existing=False,
stragglers=[], # Don't kill any stragglers automatically
)
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
sglang_dir = os.environ.get("SGLANG_DIR", "/workspace/components/backends/sglang")
# SGLang test configurations
sglang_configs = {
"aggregated": SGLangConfig(
script_name="agg.sh", marks=[pytest.mark.gpu_1], name="aggregated"
name="aggregated",
directory=sglang_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
env={},
models_port=8000,
request_payloads=[chat_payload_default(), completion_payload_default()],
),
"disaggregated": SGLangConfig(
script_name="disagg.sh", marks=[pytest.mark.gpu_2], name="disaggregated"
name="disaggregated",
directory=sglang_dir,
script_name="disagg.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
env={},
models_port=8000,
request_payloads=[chat_payload_default(), completion_payload_default()],
),
"kv_events": SGLangConfig(
script_name="agg_router.sh", marks=[pytest.mark.gpu_2], name="kv_events"
name="kv_events",
directory=sglang_dir,
script_name="agg_router.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
models_port=8000,
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
),
}
......@@ -128,162 +80,11 @@ def sglang_config_test(request):
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.sglang
def test_sglang_deployment(request, runtime_services, sglang_config_test):
"""Test SGLang deployment scenarios"""
# First check if sglang is available
try:
import sglang
logger.info(f"SGLang version: {sglang.__version__}")
except ImportError:
pytest.skip("SGLang not available")
def test_sglang_deployment(sglang_config_test, request, runtime_services):
"""Test SGLang deployment scenarios using common helpers"""
config = sglang_config_test
with SGLangProcess(config.script_name, request) as server:
# Test chat completions
prompts = [
"why is roger federer the best tennis player of all time?",
"why is novak djokovic not the best tennis player of all time?",
"why is rafa nadal a sneaky good grass court player?",
"explain the difference between federer and nadal's backhand.",
"who is the most clutch tennis player in history?",
]
responses = []
for prompt in prompts:
response = requests.post(
f"http://localhost:{server.port}/v1/chat/completions",
json={
"model": "Qwen/Qwen3-0.6B",
"messages": [
{
"role": "user",
"content": prompt,
}
],
"max_tokens": 50,
},
timeout=120,
)
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
content = result["choices"][0]["message"]["content"]
responses.append(content)
logger.info(f"SGLang {config.name} response: {content}")
# For kv_events (KV routing path), assert KV publisher/scheduler log lines appear
if config.name == "kv_events":
log_file = os.path.join(server.log_dir, "bash.log.txt")
assert os.path.exists(log_file), f"Log file not found: {log_file}"
patterns = [
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
validate_log_patterns(log_file, patterns)
# Test completions endpoint for disaggregated only
if config.name == "disaggregated":
response = requests.post(
f"http://localhost:{server.port}/v1/completions",
json={
"model": "Qwen/Qwen3-0.6B",
"prompt": "Roger Federer is the greatest tennis player of all time",
"max_tokens": 30,
},
timeout=120,
)
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
text = result["choices"][0]["text"]
assert len(text) > 0
logger.info(f"SGLang completions response: {text}")
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.sglang
@pytest.mark.slow
def test_metrics_labels(request, runtime_services):
"""
Test that the sglang backend correctly exports model labels in its metrics.
This test verifies that the model name appears as a label in the Prometheus metrics.
"""
logger.info("Starting test_metrics_labels for sglang backend")
# Configuration
model_path = "Qwen/Qwen3-0.6B"
metrics_port = 8081
# Build command to start sglang backend with metrics enabled
command = [
"python3",
"-m",
"dynamo.sglang",
"--model-path",
model_path,
"--mem-fraction-static",
"0.4", # Limit memory usage for testing
]
# Set environment for metrics
env = os.environ.copy()
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_PORT"] = str(metrics_port)
# Use ManagedProcess for consistent process management
with ManagedProcess(
command=command,
env=env,
timeout=120,
display_output=True,
health_check_urls=[
(f"http://localhost:{metrics_port}/metrics", lambda r: r.status_code == 200)
],
delayed_start=30, # Give SGLang time to initialize
):
# Give the backend a moment to fully initialize metrics
time.sleep(2)
# Fetch and verify metrics
logger.info("Fetching metrics to verify model label...")
response = requests.get(f"http://localhost:{metrics_port}/metrics", timeout=10)
assert response.status_code == 200, "Failed to fetch metrics"
metrics_text = response.text
logger.info(f"Metrics text: {metrics_text}")
# Parse the Prometheus metrics to find our label
pattern = rf'dynamo_component_requests_total\{{[^}}]*model="{re.escape(model_path)}"[^}}]*\}}\s+(\d+)'
matches = re.findall(pattern, metrics_text)
if matches:
initial_value = int(matches[0])
assert (
initial_value == 0
), f"Expected initial metric value to be 0, got {initial_value}"
else:
# Check if any dynamo_component metrics exist
if "dynamo_component" in metrics_text:
logger.info(
"✓ Metrics endpoint is working (found dynamo_component metrics)"
)
logger.warning(
"Note: dynamo_component_requests_total not found - likely because the engine didn't fully initialize"
)
logger.info("For complete testing, use a real pre-built TRT-LLM engine")
else:
pytest.fail("No dynamo_component metrics found at all")
run_serve_deployment(config, request)
@pytest.mark.skip(
......@@ -292,25 +93,4 @@ def test_metrics_labels(request, runtime_services):
def test_sglang_disagg_dp_attention(request, runtime_services):
"""Test sglang disaggregated with DP attention (requires 4 GPUs)"""
with SGLangProcess("disagg_dp_attn.sh", request) as server:
# Test chat completions with the DP attention model
response = requests.post(
f"http://localhost:{server.port}/v1/chat/completions",
json={
"model": "silence09/DeepSeek-R1-Small-2layers", # DP attention model
"messages": [{"role": "user", "content": "Tell me about MoE models"}],
"max_tokens": 50,
},
timeout=120,
)
# TODO: Once this is enabled, we can test out the rest of the HTTP endpoints around
# flush_cache and expert distribution recording
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
content = result["choices"][0]["message"]["content"]
assert len(content) > 0
logger.info(f"SGLang DP attention response: {content}")
# Kept for reference; this test uses a different launch path and is skipped
......@@ -3,18 +3,13 @@
import logging
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
import pytest
from tests.serve.common import EngineConfig, create_payload_for_config
from tests.utils.deployment_graph import (
chat_completions_response_handler,
completions_response_handler,
metrics_handler,
)
from tests.utils.engine_process import EngineProcess
from tests.serve.common import run_serve_deployment
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import chat_payload_default, completion_payload_default
logger = logging.getLogger(__name__)
......@@ -23,138 +18,68 @@ logger = logging.getLogger(__name__)
class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios"""
stragglers: list[str] = field(default_factory=lambda: ["TRTLLM:EngineCore"])
class TRTLLMProcess(EngineProcess):
"""Simple process manager for trtllm shell scripts"""
def __init__(self, config: TRTLLMConfig, request):
self.port = 8000
self.backend_metrics_port = 8081
self.config = config
self.dir = config.directory
script_path = os.path.join(self.dir, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"trtllm script not found: {script_path}")
# Set these env vars to customize model launched by launch script to match test
os.environ["MODEL_PATH"] = config.model
os.environ["SERVED_MODEL_NAME"] = config.model
command = ["bash", script_path]
super().__init__(
command=command,
timeout=config.timeout,
display_output=True,
working_dir=self.dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=config.delayed_start,
terminate_existing=False, # If true, will call all bash processes including myself
stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
)
def run_trtllm_test_case(config: TRTLLMConfig, request) -> None:
payload = create_payload_for_config(config)
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
if endpoint == "metrics":
response = server_process.get_metrics(
server_process.backend_metrics_port
)
response_handler(response)
else:
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
trtllm_dir = os.environ.get("TRTLLM_DIR", "/workspace/components/backends/trtllm")
# trtllm test configurations
trtllm_configs = {
"aggregated": TRTLLMConfig(
name="aggregated",
directory="/workspace/components/backends/trtllm",
directory=trtllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
),
"disaggregated": TRTLLMConfig(
name="disaggregated",
directory="/workspace/components/backends/trtllm",
directory=trtllm_dir,
script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
),
# TODO: These are sanity tests that the kv router examples launch
# and inference without error, but do not do detailed checks on the
# behavior of KV routing.
"aggregated_router": TRTLLMConfig(
name="aggregated_router",
directory="/workspace/components/backends/trtllm",
directory=trtllm_dir,
script_name="agg_router.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
),
"disaggregated_router": TRTLLMConfig(
name="disaggregated_router",
directory="/workspace/components/backends/trtllm",
directory=trtllm_dir,
script_name="disagg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
),
"aggregated_metrics": TRTLLMConfig(
name="aggregated_metrics",
directory="/workspace/components/backends/trtllm",
script_name="agg_metrics.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=[
"v1/chat/completions",
"metrics",
], # Make a request to make sure the model is loaded and metrics are published.
response_handlers=[chat_completions_response_handler, metrics_handler],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
),
}
......@@ -170,24 +95,18 @@ def trtllm_config_test(request):
return trtllm_configs[request.param]
@pytest.mark.trtllm_marker
@pytest.mark.e2e
@pytest.mark.slow
def test_deployment(trtllm_config_test, request, runtime_services):
"""
Test dynamo deployments with different configurations.
"""
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
config = trtllm_config_test
logger.info(f"Using model: {config.model}")
logger.info(f"Script: {config.script_name}")
run_trtllm_test_case(config, request)
extra_env = {"MODEL_PATH": config.model, "SERVED_MODEL_NAME": config.model}
run_serve_deployment(config, request, extra_env=extra_env)
# TODO make this a normal guy
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.trtllm_marker
......@@ -209,11 +128,12 @@ def test_chat_only_aggregated_with_test_logits_processor(
directory=base.directory,
script_name=base.script_name, # agg.sh
marks=[], # not used by this direct test
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
request_payloads=[
chat_payload_default(),
],
model="Qwen/Qwen3-0.6B",
delayed_start=base.delayed_start,
timeout=base.timeout,
)
run_trtllm_test_case(config, request)
run_serve_deployment(config, request)
......@@ -3,177 +3,86 @@
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional
from dataclasses import dataclass, field
import pytest
from tests.serve.common import EngineConfig
from tests.serve.common import create_payload_for_config as base_create_payload
from tests.utils.deployment_graph import (
Payload,
chat_completions_response_handler,
completions_response_handler,
from tests.serve.common import run_serve_deployment
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
completion_payload_default,
metric_payload_default,
)
from tests.utils.engine_process import EngineProcess
logger = logging.getLogger(__name__)
def create_payload_for_config(config: "VLLMConfig") -> Payload:
"""Create a payload using the model from the vLLM config"""
if config.name in ["multimodal_agg_llava", "multimodal_agg_qwen"]:
# Special handling for multimodal models
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": False,
},
repeat_count=1,
expected_log=[],
expected_response=["bus"],
)
elif config.name == "multimodal_video_agg":
# Special handling for multimodal models
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
},
],
}
],
"max_tokens": 300,
"stream": False,
},
repeat_count=1,
expected_log=[],
expected_response=["rabbit"],
)
else:
# Use base implementation for standard text models
return base_create_payload(config)
@dataclass
class VLLMConfig(EngineConfig):
"""Configuration for vLLM test scenarios"""
args: Optional[List[str]] = None
class VLLMProcess(EngineProcess):
"""Simple process manager for vllm shell scripts"""
def __init__(self, config: VLLMConfig, request):
self.port = 8000
self.config = config
self.dir = config.directory
script_path = os.path.join(self.dir, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"vLLM script not found: {script_path}")
command = ["bash", script_path]
if config.args:
command.extend(config.args)
stragglers: list[str] = field(default_factory=lambda: ["VLLM:EngineCore"])
super().__init__(
command=command,
timeout=config.timeout,
display_output=True,
working_dir=self.dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=config.delayed_start,
terminate_existing=False, # If true, will call all bash processes including myself
stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
)
vllm_dir = os.environ.get("VLLM_DIR", "/workspace/components/backends/vllm")
# vLLM test configurations
vllm_configs = {
"aggregated": VLLMConfig(
name="aggregated",
directory="/workspace/components/backends/vllm",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(),
completion_payload_default(),
metric_payload_default(min_num_requests=6),
],
),
"agg-router": VLLMConfig(
name="agg-router",
directory="/workspace/components/backends/vllm",
directory=vllm_dir,
script_name="agg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
),
"disaggregated": VLLMConfig(
name="disaggregated",
directory="/workspace/components/backends/vllm",
directory=vllm_dir,
script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
),
"deepep": VLLMConfig(
name="deepep",
directory="/workspace/components/backends/vllm",
directory=vllm_dir,
script_name="dsr1_dep.sh",
marks=[
pytest.mark.gpu_2,
pytest.mark.vllm,
pytest.mark.h100,
],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-V2-Lite",
args=[
script_args=[
"--model",
"deepseek-ai/DeepSeek-V2-Lite",
"--num-nodes",
......@@ -184,46 +93,84 @@ vllm_configs = {
"2",
],
timeout=700,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
directory="/workspace/examples/multimodal",
script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
marks=[pytest.mark.gpu_2],
model="llava-hf/llava-1.5-7b-hf",
args=["--model", "llava-hf/llava-1.5-7b-hf"],
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
request_payloads=[
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
)
],
),
"multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen",
directory="/workspace/examples/multimodal",
script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen2.5-VL-7B-Instruct",
delayed_start=0,
args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
script_args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
timeout=360,
request_payloads=[
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
)
],
),
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",
directory="/workspace/examples/multimodal",
script_name="video_agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
marks=[pytest.mark.gpu_2],
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
delayed_start=0,
args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
script_args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
timeout=360,
request_payloads=[
chat_payload(
[
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
},
],
repeat_count=1,
expected_response=["rabbit"],
)
],
),
# TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig(
......@@ -231,13 +178,9 @@ vllm_configs = {
# directory="/workspace/examples/multimodal",
# script_name="disagg.sh",
# marks=[pytest.mark.gpu_4, pytest.mark.vllm],
# endpoints=["v1/chat/completions"],
# response_handlers=[
# chat_completions_response_handler,
# ],
# model="llava-hf/llava-1.5-7b-hf",
# delayed_start=45,
# args=["--model", "llava-hf/llava-1.5-7b-hf"],
# script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
# ),
}
......@@ -253,41 +196,11 @@ def vllm_config_test(request):
return vllm_configs[request.param]
@pytest.mark.vllm
@pytest.mark.e2e
def test_serve_deployment(vllm_config_test, request, runtime_services):
"""
Test dynamo serve deployments with different graph configurations.
"""
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
config = vllm_config_test
payload = create_payload_for_config(config)
logger.info("Using model: %s", config.model)
logger.info("Script: %s", config.script_name)
with VLLMProcess(config, request) as server_process:
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
run_serve_deployment(config, request)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import time
from typing import Any, Dict
import requests
logger = logging.getLogger(__name__)
def send_request(
url: str,
payload: Dict[str, Any],
timeout: float = 30.0,
method: str = "POST",
log_level: int = 20,
) -> requests.Response:
"""
Send an HTTP request to the engine with detailed logging.
Args:
url: The endpoint URL
payload: The request payload (for GET, sent as query params)
timeout: Request timeout in seconds
method: HTTP method ("POST" or "GET")
Returns:
The response object
Raises:
requests.RequestException: If the request fails
"""
method_upper = method.upper()
payload_json = json.dumps(payload, indent=2)
curl_command = ""
if method_upper == "GET":
curl_command = f'curl "{url}"'
if payload:
# For GET requests, payload is sent as query parameters
query_params = "&".join(f"{k}={v}" for k, v in payload.items())
curl_command += f"?{query_params}"
else:
curl_command = f'curl -X {method_upper} "{url}"'
if method_upper == "POST":
curl_command += (
' \\\n -H "Content-Type: application/json" \\\n -d \''
+ payload_json
+ "'"
)
logger.log(log_level, "Sending request (curl equivalent):\n%s", curl_command)
start_time = time.time()
try:
if method_upper == "GET":
response = requests.get(url, params=payload, timeout=timeout)
elif method_upper == "POST":
response = requests.post(url, json=payload, timeout=timeout)
else:
# Fallback for other methods if needed
response = requests.request(
method_upper, url, json=payload, timeout=timeout
)
elapsed = time.time() - start_time
# Log response details
logger.log(
log_level,
"Received response: status=%d, elapsed=%.2fs",
response.status_code,
elapsed,
)
logger.debug("Response headers: %s", dict(response.headers))
# Try to log response body (truncated if too long)
try:
if response.headers.get("content-type", "").startswith("application/json"):
response_data = response.json()
response_str = json.dumps(response_data, indent=2)
if len(response_str) > 1000:
response_str = response_str[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_str)
else:
response_text = response.text
if len(response_text) > 1000:
response_text = response_text[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_text)
except Exception as e:
logger.debug("Could not parse response body: %s", e)
return response
except requests.exceptions.Timeout:
logger.error("Request timed out after %.2f seconds", timeout)
raise
except requests.exceptions.ConnectionError as e:
logger.error("Connection error: %s", e)
raise
except requests.exceptions.RequestException as e:
logger.error("Request failed: %s", e)
raise
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class Payload:
"""
Represents a test payload with expected response and log patterns.
"""
payload_chat: Dict[str, Any]
expected_response: List[str]
expected_log: List[str]
repeat_count: int = 1
payload_completions: Optional[Dict[str, Any]] = None
def chat_completions_response_handler(response):
"""
Process chat completions API responses.
"""
if response.status_code != 200:
return ""
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "message" in result["choices"][0], "Missing 'message' in first choice"
message = result["choices"][0]["message"]
# Check for content in all possible fields where parsers might put output:
# 1. content - standard message content
# 2. reasoning_content - for models with reasoning parsers
# 3. refusal - when the model refuses to answer
# 4. tool_calls - for function/tool calling responses
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", "")
refusal = message.get("refusal", "")
# Check for tool calls
tool_calls = message.get("tool_calls", [])
tool_content = ""
if tool_calls:
# Extract content from tool calls
tool_content = ", ".join(
call.get("function", {}).get("arguments", "")
for call in tool_calls
if call.get("function", {}).get("arguments")
)
# Return the first non-empty field in priority order
for field_content in [content, reasoning_content, refusal, tool_content]:
if field_content:
return field_content
# If all fields are empty, provide a detailed error
raise ValueError(
"All possible content fields are empty in message. "
f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
f"refusal={repr(refusal)}, tool_calls={tool_calls}"
)
def completions_response_handler(response):
"""
Process completions API responses.
"""
if response.status_code != 200:
return ""
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "text" in result["choices"][0], "Missing 'text' in first choice"
return result["choices"][0]["text"]
def metrics_handler(response):
"""Handler to check if metrics endpoint is working and contains model label."""
if response.status_code != 200:
raise AssertionError(
f"Metrics endpoint returned non-200 status code: {response.status_code}"
)
metrics_text = response.text
# Check for any model label in dynamo_component_requests_total metric
pattern = r'dynamo_component_requests_total\{[^}]*model="[^"]*"[^}]*\}\s+(\d+)'
matches = re.findall(pattern, metrics_text)
if not matches:
raise AssertionError(
"Metric 'dynamo_component_requests_total' with model label not found in metrics output"
)
# Since we send a request first, the counter should be > 0
for match in matches:
request_count = int(match)
if request_count > 0:
logger.info(
f"Found dynamo_component_requests_total with count: {request_count}"
)
return metrics_text
raise AssertionError(
"dynamo_component_requests_total exists but has count of 0 - request was not tracked"
)
......@@ -3,15 +3,19 @@
import json
import logging
import time
from typing import Any, Callable, Dict
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import requests
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api
logger = logging.getLogger(__name__)
FRONTEND_PORT = 8000
class EngineResponseError(Exception):
"""Custom exception for engine response errors"""
......@@ -19,107 +23,38 @@ class EngineResponseError(Exception):
pass
class EngineProcess(ManagedProcess):
"""Base class for LLM engine processes (vLLM, TRT-LLM, etc.)"""
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
def get_metrics(self, port=8081):
"""Curl the metrics endpoint and return the response."""
metrics_url = f"http://localhost:{port}/metrics"
logger.info(f"Curling metrics endpoint: {metrics_url}")
try:
response = requests.get(metrics_url, timeout=10)
logger.info(
f"Metrics endpoint responded with status: {response.status_code}"
)
return response
except requests.RequestException as e:
logger.error(f"Failed to curl metrics endpoint: {e}")
raise
def send_request(
self, url: str, payload: Dict[str, Any], timeout: float = 30.0
) -> requests.Response:
"""
Send a POST request to the engine with detailed logging.
class EngineLogError(Exception):
"""Custom exception for engine log validation errors"""
Args:
url: The endpoint URL
payload: The request payload
timeout: Request timeout in seconds
Returns:
The response object
pass
Raises:
requests.RequestException: If the request fails
"""
# Log the request as a curl command for easy reproduction
payload_json = json.dumps(payload, indent=2)
curl_command = f'curl -X POST "{url}" \\\n -H "Content-Type: application/json" \\\n -d \'{payload_json}\''
logger.info("Sending request (curl equivalent):\n%s", curl_command)
@dataclass
class EngineConfig:
"""Base configuration for engine test scenarios"""
start_time = time.time()
try:
response = requests.post(url, json=payload, timeout=timeout)
elapsed = time.time() - start_time
name: str
directory: str
script_name: str
marks: List[Any]
request_payloads: List[BasePayload]
model: str
# Log response details
logger.info(
"Received response: status=%d, elapsed=%.2fs",
response.status_code,
elapsed,
)
script_args: Optional[List[str]] = None
models_port: int = 8000
timeout: int = 600
delayed_start: int = 0
env: Dict[str, str] = field(default_factory=dict)
stragglers: list[str] = field(default_factory=list)
logger.debug("Response headers: %s", dict(response.headers))
# Try to log response body (truncated if too long)
try:
if response.headers.get("content-type", "").startswith(
"application/json"
):
response_data = response.json()
response_str = json.dumps(response_data, indent=2)
if len(response_str) > 1000:
response_str = response_str[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_str)
else:
response_text = response.text
if len(response_text) > 1000:
response_text = response_text[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_text)
except Exception as e:
logger.debug("Could not parse response body: %s", e)
return response
except requests.exceptions.Timeout:
logger.error("Request timed out after %.2f seconds", timeout)
raise
except requests.exceptions.ConnectionError as e:
logger.error("Connection error: %s", e)
raise
except requests.exceptions.RequestException as e:
logger.error("Request failed: %s", e)
raise
class EngineProcess(ManagedProcess):
"""Base class for LLM engine processes (vLLM, TRT-LLM, etc.)"""
def check_response(
self,
payload: Any,
payload: BasePayload,
response: requests.Response,
response_handler: Callable[[Any], str],
) -> None:
"""
Check if the response is valid and contains expected content.
......@@ -151,30 +86,93 @@ class EngineProcess(ManagedProcess):
raise EngineResponseError(error_msg)
# Extract content using the handler
try:
content = response_handler(response)
content = payload.process_response(response)
logger.info(
"Extracted content: \n%s",
content[:200] + "..." if len(content) > 200 else content,
content[:200] + "..."
if isinstance(content, str) and len(content) > 200
else content,
)
except AssertionError as e:
raise EngineResponseError(str(e))
except Exception as e:
raise EngineResponseError(f"Failed to extract content from response: {e}")
raise EngineResponseError(f"Failed to handle response: {e}")
# Optionally validate expected log patterns after response handling
if payload.expected_log:
self.validate_expected_logs(payload.expected_log)
def validate_expected_logs(self, patterns: Any) -> None:
"""Validate that all regex patterns are present in the current logs.
Reads the full log via ManagedProcess.read_logs and searches for each
provided regex pattern. Raises EngineLogError if any are missing.
"""
import re # local import to keep module load minimal
content = self.read_logs() or ""
if not content:
raise EngineResponseError("Response contained empty content")
raise EngineLogError(
f"Log file not available or empty at path: {self.log_path}"
)
if hasattr(payload, "expected_response") and payload.expected_response:
missing_expected = []
for expected in payload.expected_response:
if expected not in content:
missing_expected.append(expected)
compiled = [re.compile(p) for p in patterns]
missing = []
for pattern, rx in zip(patterns, compiled):
if not rx.search(content):
missing.append(pattern)
if missing_expected:
raise EngineResponseError(
f"Expected content not found in response. Missing: {missing_expected}"
)
else:
logger.info(
f"SUCCESS: All expected content ({payload.expected_response}) found in response"
)
if missing:
sample = content[-1000:] if len(content) > 1000 else content
raise EngineLogError(
f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
)
logger.info(f"SUCCESS: All expected log patterns: {patterns} found")
@classmethod
def from_script(
cls,
config: EngineConfig,
request: Any,
extra_env: Optional[Dict[str, str]] = None,
) -> "EngineProcess":
"""Factory to create an EngineProcess configured to run a launch script."""
assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"
directory = config.directory
script_path = os.path.join(directory, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"Script not found: {script_path}")
command: List[str] = ["bash", script_path]
if config.script_args:
command.extend(config.script_args)
env = os.environ.copy()
if getattr(config, "env", None):
env.update(config.env)
if extra_env:
env.update(extra_env)
return cls(
command=command,
env=env,
timeout=config.timeout,
display_output=True,
working_dir=directory,
health_check_ports=[],
health_check_urls=[
(f"http://localhost:{config.models_port}/v1/models", check_models_api),
(
f"http://localhost:{config.models_port}/health",
check_health_generate,
),
],
delayed_start=config.delayed_start,
terminate_existing=False,
stragglers=config.stragglers,
log_dir=request.node.name,
)
......@@ -73,6 +73,7 @@ class ManagedProcess:
env: Optional[dict] = None
health_check_ports: List[int] = field(default_factory=list)
health_check_urls: List[Any] = field(default_factory=list)
health_check_funcs: List[Any] = field(default_factory=list)
delayed_start: int = 0
timeout: int = 300
working_dir: Optional[str] = None
......@@ -93,6 +94,24 @@ class ManagedProcess:
_tee_proc = None
_sed_proc = None
@property
def log_path(self):
"""Return the absolute path to the process log file if available."""
return self._log_path
def read_logs(self) -> str:
"""Read and return the entire contents of the process log file.
Returns an empty string if the log file is not yet available.
"""
try:
if self._log_path and os.path.exists(self._log_path):
with open(self._log_path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
except Exception as e:
self._logger.warning("Could not read log file %s: %s", self._log_path, e)
return ""
def __enter__(self):
try:
self._logger = logging.getLogger(self.__class__.__name__)
......@@ -109,6 +128,7 @@ class ManagedProcess:
time.sleep(self.delayed_start)
elapsed = self._check_ports(self.timeout)
self._check_urls(self.timeout - elapsed)
self._check_funcs(self.timeout - elapsed)
return self
......@@ -121,44 +141,73 @@ class ManagedProcess:
)
raise
def __exit__(self, exc_type, exc_val, exc_tb):
self._terminate_process_group()
def _cleanup_stragglers(self):
"""Clean up straggler processes - called during exit and signal handling"""
try:
if self.stragglers or self.straggler_commands:
self._logger.info(
"Checking for straggler processes: stragglers=%s, straggler_commands=%s",
self.stragglers,
self.straggler_commands,
)
process_list = [self.proc, self._tee_proc, self._sed_proc]
for process in process_list:
if process:
for ps_process in psutil.process_iter(["name", "cmdline"]):
try:
if process.stdout:
process.stdout.close()
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir:
self._remove_directory(self.data_dir)
for ps_process in psutil.process_iter(["name", "cmdline"]):
try:
if ps_process.name() in self.stragglers:
self._logger.info(
"Terminating Straggler %s %s", ps_process.name(), ps_process.pid
)
terminate_process_tree(ps_process.pid, self._logger)
for cmdline in self.straggler_commands:
if cmdline in " ".join(ps_process.cmdline()):
process_name = ps_process.name()
if process_name in self.stragglers:
self._logger.info(
"Terminating Straggler Cmdline %s %s %s",
ps_process.name(),
ps_process.pid,
cmdline,
"Terminating Straggler %s %s", process_name, ps_process.pid
)
terminate_process_tree(ps_process.pid, self._logger)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
# Process may have terminated or become inaccessible during iteration
pass
# Check command line arguments
cmdline = ps_process.cmdline()
cmdline_str = " ".join(cmdline) if cmdline else ""
for straggler_cmd in self.straggler_commands:
if straggler_cmd in cmdline_str:
self._logger.info(
"Terminating Straggler Cmdline %s %s %s",
process_name,
ps_process.pid,
straggler_cmd,
)
terminate_process_tree(ps_process.pid, self._logger)
break # Avoid terminating the same process multiple times
except (
psutil.NoSuchProcess,
psutil.AccessDenied,
psutil.ZombieProcess,
):
# Process may have terminated or become inaccessible during iteration
pass
except Exception as e:
# Catch any other unexpected errors to ensure cleanup continues
self._logger.warning("Error checking process: %s", e)
except Exception as e:
# Ensure that any error in straggler cleanup doesn't prevent other cleanup
self._logger.error("Error during straggler cleanup: %s", e)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self._terminate_process_group()
process_list = [self.proc, self._tee_proc, self._sed_proc]
for process in process_list:
if process:
try:
if process.stdout:
process.stdout.close()
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir:
self._remove_directory(self.data_dir)
finally:
# Always run straggler cleanup, even if interrupted
self._cleanup_stragglers()
def _start_process(self):
assert self._command_name
......@@ -327,7 +376,7 @@ class ManagedProcess:
elapsed += self._check_url(url, timeout - elapsed)
return elapsed
def _check_url(self, url, timeout=30, sleep=1, log_interval=10):
def _check_url(self, url, timeout=30, sleep=1, log_interval=20):
if isinstance(url, tuple):
response_check = url[1]
url = url[0]
......@@ -403,6 +452,70 @@ class ManagedProcess:
)
raise RuntimeError("FAILED: Check URL: %s" % url)
def _check_funcs(self, timeout):
elapsed = 0.0
for func in self.health_check_funcs:
elapsed += self._check_func(func, timeout - elapsed)
return elapsed
def _check_func(self, func, timeout=30, sleep=1, log_interval=20):
start_time = time.time()
func_name = getattr(func, "__name__", str(func))
self._logger.info("Running custom health check '%s'", func_name)
elapsed = 0.0
attempt = 0
last_log_time = 0.0
while elapsed < timeout:
self._check_process_alive("while waiting for health check")
attempt += 1
check_failed = False
failure_reason = None
try:
# Prefer functions that accept remaining timeout; fall back to no-arg call
try:
result = func(timeout - elapsed)
except TypeError:
result = func()
if bool(result):
self._logger.info(
"SUCCESS: Custom health check '%s' passed (attempt=%d, elapsed=%.1fs)",
func_name,
attempt,
elapsed,
)
return time.time() - start_time
else:
check_failed = True
failure_reason = "returned False"
except Exception as e:
check_failed = True
failure_reason = f"exception: {e}"
if check_failed and elapsed - last_log_time >= log_interval:
self._logger.info(
"Still waiting on custom health check '%s' (%s) (attempt=%d, elapsed=%.1fs)",
func_name,
failure_reason,
attempt,
elapsed,
)
last_log_time = elapsed
time.sleep(sleep)
elapsed = time.time() - start_time
self._logger.error(
"FAILED: Custom health check '%s' (attempts=%d, elapsed=%.1fs)",
func_name,
attempt,
elapsed,
)
raise RuntimeError("FAILED: Custom health check")
def _terminate_existing(self):
if self.terminate_existing:
for proc in psutil.process_iter(["name", "cmdline"]):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request
from tests.utils.payloads import ChatPayload, CompletionPayload, MetricsPayload
# Common default text prompt used across tests
TEXT_PROMPT = "Tell me a short joke about AI."
def chat_payload_default(
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> ChatPayload:
return ChatPayload(
body={
"messages": [
{
"role": "user",
"content": TEXT_PROMPT,
}
],
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or ["AI"],
)
def completion_payload_default(
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> CompletionPayload:
return CompletionPayload(
body={
"prompt": TEXT_PROMPT,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or ["AI"],
)
def metric_payload_default(
min_num_requests: int,
repeat_count: int = 1,
expected_log: Optional[List[str]] = None,
) -> MetricsPayload:
return MetricsPayload(
body={},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=[],
min_num_requests=min_num_requests,
)
def chat_payload(
content: Union[str, List[Dict[str, Any]]],
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 300,
temperature: Optional[float] = None,
stream: bool = False,
) -> ChatPayload:
body: Dict[str, Any] = {
"messages": [
{
"role": "user",
"content": content,
}
],
"max_tokens": max_tokens,
"stream": stream,
}
if temperature is not None:
body["temperature"] = temperature
return ChatPayload(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
def completion_payload(
prompt: str,
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> CompletionPayload:
return CompletionPayload(
body={
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
# Build small request-based health checks for chat and completions
# these should only be used as a last resort. Generally want to use an actual health check
def make_chat_health_check(port: int, model: str):
def _check_chat_endpoint(remaining_timeout: float = 30.0) -> bool:
payload = chat_payload_default(
repeat_count=1,
expected_response=[],
max_tokens=8,
temperature=0.0,
stream=False,
).with_model(model)
payload.port = port
try:
resp = send_request(
payload.url(),
payload.body,
timeout=min(max(1.0, remaining_timeout), 5.0),
method=payload.method,
log_level=10,
)
# Validate structure only; expected_response is empty
_ = payload.response_handler(resp)
return True
except Exception:
return False
return _check_chat_endpoint
def make_completions_health_check(port: int, model: str):
def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
payload = completion_payload_default(
repeat_count=1,
expected_response=[],
max_tokens=8,
temperature=0.0,
stream=False,
).with_model(model)
payload.port = port
try:
resp = send_request(
payload.url(),
payload.body,
timeout=min(max(1.0, remaining_timeout), 5.0),
method=payload.method,
log_level=10,
)
out = payload.response_handler(resp)
if not out:
raise ValueError("")
return True
except Exception:
return False
return _check_completions_endpoint
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List
logger = logging.getLogger(__name__)
@dataclass
class BasePayload:
"""Generic payload body plus expectations and repeat count."""
body: Dict[str, Any]
expected_response: List[str]
expected_log: List[str]
repeat_count: int = 1
timeout: int = 30
# Connection info
host: str = "localhost"
port: int = 8000
endpoint: str = ""
method: str = "POST"
def url(self) -> str:
ep = self.endpoint.lstrip("/")
return f"http://{self.host}:{self.port}/{ep}"
def with_model(self, model):
p = deepcopy(self)
if "model" not in p.body:
p.body = {**p.body, "model": model}
return p
def response_handler(self, response: Any) -> str:
"""Extract a text representation of the response for logging/validation."""
raise NotImplementedError("Subclasses must implement response_handler()")
def validate(self, response: Any, content: str) -> None:
"""Default validation: ensure expected substrings appear in content."""
if self.expected_response:
missing_expected = []
for expected in self.expected_response:
if not content or expected not in content:
missing_expected.append(expected)
if missing_expected:
raise AssertionError(
f"Expected content not found in response. Missing: {missing_expected}"
)
logger.info(f"SUCCESS: All expected_responses: {self.expected_response} found.")
def process_response(self, response: Any) -> str:
"""Convenience: run response_handler then validate; return content."""
content = self.response_handler(response)
self.validate(response, content)
return content
@dataclass
class ChatPayload(BasePayload):
"""Payload for chat completions endpoint."""
endpoint: str = "/v1/chat/completions"
@staticmethod
def extract_content(response):
"""
Process chat completions API responses.
"""
response.raise_for_status()
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "message" in result["choices"][0], "Missing 'message' in first choice"
# Check for content in all possible fields where parsers might put output:
# 1. content - standard message content
# 2. reasoning_content - for models with reasoning parsers
# 3. refusal - when the model refuses to answer
# 4. tool_calls - for function/tool calling responses
message = result["choices"][0]["message"]
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", "")
refusal = message.get("refusal", "")
tool_calls = message.get("tool_calls", [])
tool_content = ""
if tool_calls:
tool_content = ", ".join(
call.get("function", {}).get("arguments", "")
for call in tool_calls
if call.get("function", {}).get("arguments")
)
for field_content in [content, reasoning_content, refusal, tool_content]:
if field_content:
return field_content
raise ValueError(
"All possible content fields are empty in message. "
f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
f"refusal={repr(refusal)}, tool_calls={tool_calls}"
)
def response_handler(self, response: Any) -> str:
return ChatPayload.extract_content(response)
@dataclass
class CompletionPayload(BasePayload):
"""Payload for completions endpoint."""
endpoint: str = "/v1/completions"
@staticmethod
def extract_text(response):
"""
Process completions API responses.
"""
response.raise_for_status()
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "text" in result["choices"][0], "Missing 'text' in first choice"
return result["choices"][0]["text"]
def response_handler(self, response: Any) -> str:
return CompletionPayload.extract_text(response)
@dataclass
class MetricsPayload(BasePayload):
endpoint: str = "/metrics"
method: str = "GET"
port: int = 8081
min_num_requests: int = 1
def with_model(self, model):
# Metrics does not use model in request body
return self
def response_handler(self, response: Any) -> str:
response.raise_for_status()
return response.text
def validate(self, response: Any, content: str) -> None:
pattern = r'dynamo_component_requests_total\{[^}]*model="[^"]*"[^}]*\}\s+(\d+)'
matches = re.findall(pattern, content)
if not matches:
raise AssertionError(
"Metric 'dynamo_component_requests_total' with model label not found in metrics output"
)
for match in matches:
request_count = int(match)
if request_count >= self.min_num_requests:
logger.info(
f"SUCCESS: Found dynamo_component_requests_total with count: {request_count}"
)
return
raise AssertionError(
f"dynamo_component_requests_total exists but has count {request_count} which is less than required {self.min_num_requests}"
)
def check_models_api(response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
# Additional health check helpers
def check_health_generate(response):
"""Validate /health reports a 'generate' endpoint.
Returns True if either of the following is found:
- "endpoints" contains a string mentioning 'generate'
- "instances" contains an object with endpoint == 'generate'
"""
try:
if response.status_code != 200:
return False
data = response.json()
# Check endpoints list for any entry containing 'generate'
endpoints = data.get("endpoints", []) or []
for ep in endpoints:
if isinstance(ep, str) and "generate" in ep:
return True
# Check instances for an entry with endpoint == 'generate'
instances = data.get("instances", []) or []
for inst in instances:
if isinstance(inst, dict) and inst.get("endpoint") == "generate":
return True
return False
except Exception:
return False
# backwards compatiability
def completions_response_handler(response):
return CompletionPayload.extract_text(response)
def chat_completions_response_handler(response):
return ChatPayload.extract_content(response)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment