Unverified Commit 93208162 authored by Alec's avatar Alec Committed by GitHub
Browse files

refactor: standardize e2e tests across 3 frameworks (#2827)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
parent f0cea269
...@@ -3,114 +3,66 @@ ...@@ -3,114 +3,66 @@
import logging import logging
import os import os
import re from dataclasses import dataclass, field
import time
from dataclasses import dataclass
from typing import Any, List
import pytest import pytest
import requests
from tests.utils.managed_process import ManagedProcess from tests.serve.common import run_serve_deployment
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import chat_payload_default, completion_payload_default
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def validate_log_patterns(log_file, patterns):
"""Validate log patterns after test completion."""
if not os.path.exists(log_file):
raise AssertionError(f"Log file not found: {log_file}")
with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
content = f.read()
compiled = [re.compile(p) for p in patterns]
missing = []
for pattern, rx in zip(patterns, compiled):
if not rx.search(content):
missing.append(pattern)
if missing:
# Include sample of log content for debugging
sample = content[-1000:] if len(content) > 1000 else content
raise AssertionError(
f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
)
return True
@dataclass @dataclass
class SGLangConfig: class SGLangConfig(EngineConfig):
"""Configuration for SGLang test scenarios""" """Configuration for SGLang test scenarios"""
script_name: str stragglers: list[str] = field(default_factory=lambda: ["SGLANG:EngineCore"])
marks: List[Any]
name: str
class SGLangProcess(ManagedProcess):
"""Simple process manager for sglang shell scripts"""
def __init__(self, script_name, request):
self.port = 8000
sglang_dir = os.environ.get(
"SGLANG_DIR", "/workspace/components/backends/sglang"
)
script_path = os.path.join(sglang_dir, "launch", script_name)
# Verify script exists
if not os.path.exists(script_path):
raise FileNotFoundError(f"SGLang script not found: {script_path}")
# Make script executable and run it
command = ["bash", script_path]
# Focus kv-router logs for kv_events run sglang_dir = os.environ.get("SGLANG_DIR", "/workspace/components/backends/sglang")
env = os.environ.copy()
if script_name == "agg_router.sh":
env.setdefault(
"DYN_LOG",
"dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
)
super().__init__(
command=command,
env=env,
timeout=900,
display_output=True,
working_dir=sglang_dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=60, # Give SGLang more time to fully start
terminate_existing=False,
stragglers=[], # Don't kill any stragglers automatically
)
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
# SGLang test configurations
sglang_configs = { sglang_configs = {
"aggregated": SGLangConfig( "aggregated": SGLangConfig(
script_name="agg.sh", marks=[pytest.mark.gpu_1], name="aggregated" name="aggregated",
directory=sglang_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
env={},
models_port=8000,
request_payloads=[chat_payload_default(), completion_payload_default()],
), ),
"disaggregated": SGLangConfig( "disaggregated": SGLangConfig(
script_name="disagg.sh", marks=[pytest.mark.gpu_2], name="disaggregated" name="disaggregated",
directory=sglang_dir,
script_name="disagg.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
env={},
models_port=8000,
request_payloads=[chat_payload_default(), completion_payload_default()],
), ),
"kv_events": SGLangConfig( "kv_events": SGLangConfig(
script_name="agg_router.sh", marks=[pytest.mark.gpu_2], name="kv_events" name="kv_events",
directory=sglang_dir,
script_name="agg_router.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen3-0.6B",
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
models_port=8000,
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
), ),
} }
...@@ -128,162 +80,11 @@ def sglang_config_test(request): ...@@ -128,162 +80,11 @@ def sglang_config_test(request):
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.sglang @pytest.mark.sglang
def test_sglang_deployment(request, runtime_services, sglang_config_test): def test_sglang_deployment(sglang_config_test, request, runtime_services):
"""Test SGLang deployment scenarios""" """Test SGLang deployment scenarios using common helpers"""
# First check if sglang is available
try:
import sglang
logger.info(f"SGLang version: {sglang.__version__}")
except ImportError:
pytest.skip("SGLang not available")
config = sglang_config_test config = sglang_config_test
run_serve_deployment(config, request)
with SGLangProcess(config.script_name, request) as server:
# Test chat completions
prompts = [
"why is roger federer the best tennis player of all time?",
"why is novak djokovic not the best tennis player of all time?",
"why is rafa nadal a sneaky good grass court player?",
"explain the difference between federer and nadal's backhand.",
"who is the most clutch tennis player in history?",
]
responses = []
for prompt in prompts:
response = requests.post(
f"http://localhost:{server.port}/v1/chat/completions",
json={
"model": "Qwen/Qwen3-0.6B",
"messages": [
{
"role": "user",
"content": prompt,
}
],
"max_tokens": 50,
},
timeout=120,
)
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
content = result["choices"][0]["message"]["content"]
responses.append(content)
logger.info(f"SGLang {config.name} response: {content}")
# For kv_events (KV routing path), assert KV publisher/scheduler log lines appear
if config.name == "kv_events":
log_file = os.path.join(server.log_dir, "bash.log.txt")
assert os.path.exists(log_file), f"Log file not found: {log_file}"
patterns = [
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
validate_log_patterns(log_file, patterns)
# Test completions endpoint for disaggregated only
if config.name == "disaggregated":
response = requests.post(
f"http://localhost:{server.port}/v1/completions",
json={
"model": "Qwen/Qwen3-0.6B",
"prompt": "Roger Federer is the greatest tennis player of all time",
"max_tokens": 30,
},
timeout=120,
)
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
text = result["choices"][0]["text"]
assert len(text) > 0
logger.info(f"SGLang completions response: {text}")
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.sglang
@pytest.mark.slow
def test_metrics_labels(request, runtime_services):
"""
Test that the sglang backend correctly exports model labels in its metrics.
This test verifies that the model name appears as a label in the Prometheus metrics.
"""
logger.info("Starting test_metrics_labels for sglang backend")
# Configuration
model_path = "Qwen/Qwen3-0.6B"
metrics_port = 8081
# Build command to start sglang backend with metrics enabled
command = [
"python3",
"-m",
"dynamo.sglang",
"--model-path",
model_path,
"--mem-fraction-static",
"0.4", # Limit memory usage for testing
]
# Set environment for metrics
env = os.environ.copy()
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_PORT"] = str(metrics_port)
# Use ManagedProcess for consistent process management
with ManagedProcess(
command=command,
env=env,
timeout=120,
display_output=True,
health_check_urls=[
(f"http://localhost:{metrics_port}/metrics", lambda r: r.status_code == 200)
],
delayed_start=30, # Give SGLang time to initialize
):
# Give the backend a moment to fully initialize metrics
time.sleep(2)
# Fetch and verify metrics
logger.info("Fetching metrics to verify model label...")
response = requests.get(f"http://localhost:{metrics_port}/metrics", timeout=10)
assert response.status_code == 200, "Failed to fetch metrics"
metrics_text = response.text
logger.info(f"Metrics text: {metrics_text}")
# Parse the Prometheus metrics to find our label
pattern = rf'dynamo_component_requests_total\{{[^}}]*model="{re.escape(model_path)}"[^}}]*\}}\s+(\d+)'
matches = re.findall(pattern, metrics_text)
if matches:
initial_value = int(matches[0])
assert (
initial_value == 0
), f"Expected initial metric value to be 0, got {initial_value}"
else:
# Check if any dynamo_component metrics exist
if "dynamo_component" in metrics_text:
logger.info(
"✓ Metrics endpoint is working (found dynamo_component metrics)"
)
logger.warning(
"Note: dynamo_component_requests_total not found - likely because the engine didn't fully initialize"
)
logger.info("For complete testing, use a real pre-built TRT-LLM engine")
else:
pytest.fail("No dynamo_component metrics found at all")
@pytest.mark.skip( @pytest.mark.skip(
...@@ -292,25 +93,4 @@ def test_metrics_labels(request, runtime_services): ...@@ -292,25 +93,4 @@ def test_metrics_labels(request, runtime_services):
def test_sglang_disagg_dp_attention(request, runtime_services): def test_sglang_disagg_dp_attention(request, runtime_services):
"""Test sglang disaggregated with DP attention (requires 4 GPUs)""" """Test sglang disaggregated with DP attention (requires 4 GPUs)"""
with SGLangProcess("disagg_dp_attn.sh", request) as server: # Kept for reference; this test uses a different launch path and is skipped
# Test chat completions with the DP attention model
response = requests.post(
f"http://localhost:{server.port}/v1/chat/completions",
json={
"model": "silence09/DeepSeek-R1-Small-2layers", # DP attention model
"messages": [{"role": "user", "content": "Tell me about MoE models"}],
"max_tokens": 50,
},
timeout=120,
)
# TODO: Once this is enabled, we can test out the rest of the HTTP endpoints around
# flush_cache and expert distribution recording
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
content = result["choices"][0]["message"]["content"]
assert len(content) > 0
logger.info(f"SGLang DP attention response: {content}")
...@@ -3,18 +3,13 @@ ...@@ -3,18 +3,13 @@
import logging import logging
import os import os
import time from dataclasses import dataclass, field
from dataclasses import dataclass
import pytest import pytest
from tests.serve.common import EngineConfig, create_payload_for_config from tests.serve.common import run_serve_deployment
from tests.utils.deployment_graph import ( from tests.utils.engine_process import EngineConfig
chat_completions_response_handler, from tests.utils.payload_builder import chat_payload_default, completion_payload_default
completions_response_handler,
metrics_handler,
)
from tests.utils.engine_process import EngineProcess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -23,138 +18,68 @@ logger = logging.getLogger(__name__) ...@@ -23,138 +18,68 @@ logger = logging.getLogger(__name__)
class TRTLLMConfig(EngineConfig): class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios""" """Configuration for trtllm test scenarios"""
stragglers: list[str] = field(default_factory=lambda: ["TRTLLM:EngineCore"])
class TRTLLMProcess(EngineProcess):
"""Simple process manager for trtllm shell scripts"""
def __init__(self, config: TRTLLMConfig, request):
self.port = 8000
self.backend_metrics_port = 8081
self.config = config
self.dir = config.directory
script_path = os.path.join(self.dir, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"trtllm script not found: {script_path}")
# Set these env vars to customize model launched by launch script to match test
os.environ["MODEL_PATH"] = config.model
os.environ["SERVED_MODEL_NAME"] = config.model
command = ["bash", script_path]
super().__init__(
command=command,
timeout=config.timeout,
display_output=True,
working_dir=self.dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=config.delayed_start,
terminate_existing=False, # If true, will call all bash processes including myself
stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
)
def run_trtllm_test_case(config: TRTLLMConfig, request) -> None:
payload = create_payload_for_config(config)
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
if endpoint == "metrics":
response = server_process.get_metrics(
server_process.backend_metrics_port
)
response_handler(response)
else:
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
trtllm_dir = os.environ.get("TRTLLM_DIR", "/workspace/components/backends/trtllm")
# trtllm test configurations # trtllm test configurations
trtllm_configs = { trtllm_configs = {
"aggregated": TRTLLMConfig( "aggregated": TRTLLMConfig(
name="aggregated", name="aggregated",
directory="/workspace/components/backends/trtllm", directory=trtllm_dir,
script_name="agg.sh", script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker], marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
), ),
"disaggregated": TRTLLMConfig( "disaggregated": TRTLLMConfig(
name="disaggregated", name="disaggregated",
directory="/workspace/components/backends/trtllm", directory=trtllm_dir,
script_name="disagg.sh", script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker], marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
), ),
# TODO: These are sanity tests that the kv router examples launch
# and inference without error, but do not do detailed checks on the
# behavior of KV routing.
"aggregated_router": TRTLLMConfig( "aggregated_router": TRTLLMConfig(
name="aggregated_router", name="aggregated_router",
directory="/workspace/components/backends/trtllm", directory=trtllm_dir,
script_name="agg_router.sh", script_name="agg_router.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker], marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
), ),
"disaggregated_router": TRTLLMConfig( "disaggregated_router": TRTLLMConfig(
name="disaggregated_router", name="disaggregated_router",
directory="/workspace/components/backends/trtllm", directory=trtllm_dir,
script_name="disagg_router.sh", script_name="disagg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker], marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
),
"aggregated_metrics": TRTLLMConfig(
name="aggregated_metrics",
directory="/workspace/components/backends/trtllm",
script_name="agg_metrics.sh",
marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
endpoints=[
"v1/chat/completions",
"metrics",
], # Make a request to make sure the model is loaded and metrics are published.
response_handlers=[chat_completions_response_handler, metrics_handler],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
), ),
} }
...@@ -170,24 +95,18 @@ def trtllm_config_test(request): ...@@ -170,24 +95,18 @@ def trtllm_config_test(request):
return trtllm_configs[request.param] return trtllm_configs[request.param]
@pytest.mark.trtllm_marker
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.slow
def test_deployment(trtllm_config_test, request, runtime_services): def test_deployment(trtllm_config_test, request, runtime_services):
""" """
Test dynamo deployments with different configurations. Test dynamo deployments with different configurations.
""" """
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
config = trtllm_config_test config = trtllm_config_test
logger.info(f"Using model: {config.model}") extra_env = {"MODEL_PATH": config.model, "SERVED_MODEL_NAME": config.model}
logger.info(f"Script: {config.script_name}") run_serve_deployment(config, request, extra_env=extra_env)
run_trtllm_test_case(config, request)
# TODO make this a normal guy
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.trtllm_marker @pytest.mark.trtllm_marker
...@@ -209,11 +128,12 @@ def test_chat_only_aggregated_with_test_logits_processor( ...@@ -209,11 +128,12 @@ def test_chat_only_aggregated_with_test_logits_processor(
directory=base.directory, directory=base.directory,
script_name=base.script_name, # agg.sh script_name=base.script_name, # agg.sh
marks=[], # not used by this direct test marks=[], # not used by this direct test
endpoints=["v1/chat/completions"], request_payloads=[
response_handlers=[chat_completions_response_handler], chat_payload_default(),
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
delayed_start=base.delayed_start, delayed_start=base.delayed_start,
timeout=base.timeout, timeout=base.timeout,
) )
run_trtllm_test_case(config, request) run_serve_deployment(config, request)
...@@ -3,177 +3,86 @@ ...@@ -3,177 +3,86 @@
import logging import logging
import os import os
import time from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import List, Optional
import pytest import pytest
from tests.serve.common import EngineConfig from tests.serve.common import run_serve_deployment
from tests.serve.common import create_payload_for_config as base_create_payload from tests.utils.engine_process import EngineConfig
from tests.utils.deployment_graph import ( from tests.utils.payload_builder import (
Payload, chat_payload,
chat_completions_response_handler, chat_payload_default,
completions_response_handler, completion_payload_default,
metric_payload_default,
) )
from tests.utils.engine_process import EngineProcess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_payload_for_config(config: "VLLMConfig") -> Payload:
"""Create a payload using the model from the vLLM config"""
if config.name in ["multimodal_agg_llava", "multimodal_agg_qwen"]:
# Special handling for multimodal models
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": False,
},
repeat_count=1,
expected_log=[],
expected_response=["bus"],
)
elif config.name == "multimodal_video_agg":
# Special handling for multimodal models
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
},
],
}
],
"max_tokens": 300,
"stream": False,
},
repeat_count=1,
expected_log=[],
expected_response=["rabbit"],
)
else:
# Use base implementation for standard text models
return base_create_payload(config)
@dataclass @dataclass
class VLLMConfig(EngineConfig): class VLLMConfig(EngineConfig):
"""Configuration for vLLM test scenarios""" """Configuration for vLLM test scenarios"""
args: Optional[List[str]] = None stragglers: list[str] = field(default_factory=lambda: ["VLLM:EngineCore"])
class VLLMProcess(EngineProcess):
"""Simple process manager for vllm shell scripts"""
def __init__(self, config: VLLMConfig, request):
self.port = 8000
self.config = config
self.dir = config.directory
script_path = os.path.join(self.dir, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"vLLM script not found: {script_path}")
command = ["bash", script_path]
if config.args:
command.extend(config.args)
super().__init__(
command=command,
timeout=config.timeout,
display_output=True,
working_dir=self.dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=config.delayed_start,
terminate_existing=False, # If true, will call all bash processes including myself
stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
)
vllm_dir = os.environ.get("VLLM_DIR", "/workspace/components/backends/vllm")
# vLLM test configurations # vLLM test configurations
vllm_configs = { vllm_configs = {
"aggregated": VLLMConfig( "aggregated": VLLMConfig(
name="aggregated", name="aggregated",
directory="/workspace/components/backends/vllm", directory=vllm_dir,
script_name="agg.sh", script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.vllm], marks=[pytest.mark.gpu_1],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(),
completion_payload_default(),
metric_payload_default(min_num_requests=6),
],
), ),
"agg-router": VLLMConfig( "agg-router": VLLMConfig(
name="agg-router", name="agg-router",
directory="/workspace/components/backends/vllm", directory=vllm_dir,
script_name="agg_router.sh", script_name="agg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
},
), ),
"disaggregated": VLLMConfig( "disaggregated": VLLMConfig(
name="disaggregated", name="disaggregated",
directory="/workspace/components/backends/vllm", directory=vllm_dir,
script_name="disagg.sh", script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
), ),
"deepep": VLLMConfig( "deepep": VLLMConfig(
name="deepep", name="deepep",
directory="/workspace/components/backends/vllm", directory=vllm_dir,
script_name="dsr1_dep.sh", script_name="dsr1_dep.sh",
marks=[ marks=[
pytest.mark.gpu_2, pytest.mark.gpu_2,
pytest.mark.vllm, pytest.mark.vllm,
pytest.mark.h100, pytest.mark.h100,
], ],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="deepseek-ai/DeepSeek-V2-Lite", model="deepseek-ai/DeepSeek-V2-Lite",
args=[ script_args=[
"--model", "--model",
"deepseek-ai/DeepSeek-V2-Lite", "deepseek-ai/DeepSeek-V2-Lite",
"--num-nodes", "--num-nodes",
...@@ -184,46 +93,84 @@ vllm_configs = { ...@@ -184,46 +93,84 @@ vllm_configs = {
"2", "2",
], ],
timeout=700, timeout=700,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
],
), ),
"multimodal_agg_llava": VLLMConfig( "multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava", name="multimodal_agg_llava",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
script_name="agg.sh", script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
args=["--model", "llava-hf/llava-1.5-7b-hf"], script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
request_payloads=[
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
)
],
), ),
"multimodal_agg_qwen": VLLMConfig( "multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen", name="multimodal_agg_qwen",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
script_name="agg.sh", script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
model="Qwen/Qwen2.5-VL-7B-Instruct", model="Qwen/Qwen2.5-VL-7B-Instruct",
delayed_start=0, delayed_start=0,
args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"], script_args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
timeout=360, timeout=360,
request_payloads=[
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
)
],
), ),
"multimodal_video_agg": VLLMConfig( "multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg", name="multimodal_video_agg",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
script_name="video_agg.sh", script_name="video_agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
model="llava-hf/LLaVA-NeXT-Video-7B-hf", model="llava-hf/LLaVA-NeXT-Video-7B-hf",
delayed_start=0, delayed_start=0,
args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"], script_args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
timeout=360, timeout=360,
request_payloads=[
chat_payload(
[
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
},
],
repeat_count=1,
expected_response=["rabbit"],
)
],
), ),
# TODO: Enable this test case when we have 4 GPUs runners. # TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig( # "multimodal_disagg": VLLMConfig(
...@@ -231,13 +178,9 @@ vllm_configs = { ...@@ -231,13 +178,9 @@ vllm_configs = {
# directory="/workspace/examples/multimodal", # directory="/workspace/examples/multimodal",
# script_name="disagg.sh", # script_name="disagg.sh",
# marks=[pytest.mark.gpu_4, pytest.mark.vllm], # marks=[pytest.mark.gpu_4, pytest.mark.vllm],
# endpoints=["v1/chat/completions"],
# response_handlers=[
# chat_completions_response_handler,
# ],
# model="llava-hf/llava-1.5-7b-hf", # model="llava-hf/llava-1.5-7b-hf",
# delayed_start=45, # delayed_start=45,
# args=["--model", "llava-hf/llava-1.5-7b-hf"], # script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
# ), # ),
} }
...@@ -253,41 +196,11 @@ def vllm_config_test(request): ...@@ -253,41 +196,11 @@ def vllm_config_test(request):
return vllm_configs[request.param] return vllm_configs[request.param]
@pytest.mark.vllm
@pytest.mark.e2e @pytest.mark.e2e
def test_serve_deployment(vllm_config_test, request, runtime_services): def test_serve_deployment(vllm_config_test, request, runtime_services):
""" """
Test dynamo serve deployments with different graph configurations. Test dynamo serve deployments with different graph configurations.
""" """
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
config = vllm_config_test config = vllm_config_test
payload = create_payload_for_config(config) run_serve_deployment(config, request)
logger.info("Using model: %s", config.model)
logger.info("Script: %s", config.script_name)
with VLLMProcess(config, request) as server_process:
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import time
from typing import Any, Dict
import requests
logger = logging.getLogger(__name__)
def send_request(
url: str,
payload: Dict[str, Any],
timeout: float = 30.0,
method: str = "POST",
log_level: int = 20,
) -> requests.Response:
"""
Send an HTTP request to the engine with detailed logging.
Args:
url: The endpoint URL
payload: The request payload (for GET, sent as query params)
timeout: Request timeout in seconds
method: HTTP method ("POST" or "GET")
Returns:
The response object
Raises:
requests.RequestException: If the request fails
"""
method_upper = method.upper()
payload_json = json.dumps(payload, indent=2)
curl_command = ""
if method_upper == "GET":
curl_command = f'curl "{url}"'
if payload:
# For GET requests, payload is sent as query parameters
query_params = "&".join(f"{k}={v}" for k, v in payload.items())
curl_command += f"?{query_params}"
else:
curl_command = f'curl -X {method_upper} "{url}"'
if method_upper == "POST":
curl_command += (
' \\\n -H "Content-Type: application/json" \\\n -d \''
+ payload_json
+ "'"
)
logger.log(log_level, "Sending request (curl equivalent):\n%s", curl_command)
start_time = time.time()
try:
if method_upper == "GET":
response = requests.get(url, params=payload, timeout=timeout)
elif method_upper == "POST":
response = requests.post(url, json=payload, timeout=timeout)
else:
# Fallback for other methods if needed
response = requests.request(
method_upper, url, json=payload, timeout=timeout
)
elapsed = time.time() - start_time
# Log response details
logger.log(
log_level,
"Received response: status=%d, elapsed=%.2fs",
response.status_code,
elapsed,
)
logger.debug("Response headers: %s", dict(response.headers))
# Try to log response body (truncated if too long)
try:
if response.headers.get("content-type", "").startswith("application/json"):
response_data = response.json()
response_str = json.dumps(response_data, indent=2)
if len(response_str) > 1000:
response_str = response_str[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_str)
else:
response_text = response.text
if len(response_text) > 1000:
response_text = response_text[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_text)
except Exception as e:
logger.debug("Could not parse response body: %s", e)
return response
except requests.exceptions.Timeout:
logger.error("Request timed out after %.2f seconds", timeout)
raise
except requests.exceptions.ConnectionError as e:
logger.error("Connection error: %s", e)
raise
except requests.exceptions.RequestException as e:
logger.error("Request failed: %s", e)
raise
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class Payload:
"""
Represents a test payload with expected response and log patterns.
"""
payload_chat: Dict[str, Any]
expected_response: List[str]
expected_log: List[str]
repeat_count: int = 1
payload_completions: Optional[Dict[str, Any]] = None
def chat_completions_response_handler(response):
"""
Process chat completions API responses.
"""
if response.status_code != 200:
return ""
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "message" in result["choices"][0], "Missing 'message' in first choice"
message = result["choices"][0]["message"]
# Check for content in all possible fields where parsers might put output:
# 1. content - standard message content
# 2. reasoning_content - for models with reasoning parsers
# 3. refusal - when the model refuses to answer
# 4. tool_calls - for function/tool calling responses
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", "")
refusal = message.get("refusal", "")
# Check for tool calls
tool_calls = message.get("tool_calls", [])
tool_content = ""
if tool_calls:
# Extract content from tool calls
tool_content = ", ".join(
call.get("function", {}).get("arguments", "")
for call in tool_calls
if call.get("function", {}).get("arguments")
)
# Return the first non-empty field in priority order
for field_content in [content, reasoning_content, refusal, tool_content]:
if field_content:
return field_content
# If all fields are empty, provide a detailed error
raise ValueError(
"All possible content fields are empty in message. "
f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
f"refusal={repr(refusal)}, tool_calls={tool_calls}"
)
def completions_response_handler(response):
"""
Process completions API responses.
"""
if response.status_code != 200:
return ""
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "text" in result["choices"][0], "Missing 'text' in first choice"
return result["choices"][0]["text"]
def metrics_handler(response):
"""Handler to check if metrics endpoint is working and contains model label."""
if response.status_code != 200:
raise AssertionError(
f"Metrics endpoint returned non-200 status code: {response.status_code}"
)
metrics_text = response.text
# Check for any model label in dynamo_component_requests_total metric
pattern = r'dynamo_component_requests_total\{[^}]*model="[^"]*"[^}]*\}\s+(\d+)'
matches = re.findall(pattern, metrics_text)
if not matches:
raise AssertionError(
"Metric 'dynamo_component_requests_total' with model label not found in metrics output"
)
# Since we send a request first, the counter should be > 0
for match in matches:
request_count = int(match)
if request_count > 0:
logger.info(
f"Found dynamo_component_requests_total with count: {request_count}"
)
return metrics_text
raise AssertionError(
"dynamo_component_requests_total exists but has count of 0 - request was not tracked"
)
...@@ -3,15 +3,19 @@ ...@@ -3,15 +3,19 @@
import json import json
import logging import logging
import time import os
from typing import Any, Callable, Dict from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import requests import requests
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import BasePayload, check_health_generate, check_models_api
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
FRONTEND_PORT = 8000
class EngineResponseError(Exception): class EngineResponseError(Exception):
"""Custom exception for engine response errors""" """Custom exception for engine response errors"""
...@@ -19,107 +23,38 @@ class EngineResponseError(Exception): ...@@ -19,107 +23,38 @@ class EngineResponseError(Exception):
pass pass
class EngineProcess(ManagedProcess): class EngineLogError(Exception):
"""Base class for LLM engine processes (vLLM, TRT-LLM, etc.)""" """Custom exception for engine log validation errors"""
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
def get_metrics(self, port=8081):
"""Curl the metrics endpoint and return the response."""
metrics_url = f"http://localhost:{port}/metrics"
logger.info(f"Curling metrics endpoint: {metrics_url}")
try:
response = requests.get(metrics_url, timeout=10)
logger.info(
f"Metrics endpoint responded with status: {response.status_code}"
)
return response
except requests.RequestException as e:
logger.error(f"Failed to curl metrics endpoint: {e}")
raise
def send_request(
self, url: str, payload: Dict[str, Any], timeout: float = 30.0
) -> requests.Response:
"""
Send a POST request to the engine with detailed logging.
Args: pass
url: The endpoint URL
payload: The request payload
timeout: Request timeout in seconds
Returns:
The response object
Raises:
requests.RequestException: If the request fails
"""
# Log the request as a curl command for easy reproduction @dataclass
payload_json = json.dumps(payload, indent=2) class EngineConfig:
curl_command = f'curl -X POST "{url}" \\\n -H "Content-Type: application/json" \\\n -d \'{payload_json}\'' """Base configuration for engine test scenarios"""
logger.info("Sending request (curl equivalent):\n%s", curl_command)
start_time = time.time() name: str
try: directory: str
response = requests.post(url, json=payload, timeout=timeout) script_name: str
elapsed = time.time() - start_time marks: List[Any]
request_payloads: List[BasePayload]
model: str
# Log response details script_args: Optional[List[str]] = None
logger.info( models_port: int = 8000
"Received response: status=%d, elapsed=%.2fs", timeout: int = 600
response.status_code, delayed_start: int = 0
elapsed, env: Dict[str, str] = field(default_factory=dict)
) stragglers: list[str] = field(default_factory=list)
logger.debug("Response headers: %s", dict(response.headers))
# Try to log response body (truncated if too long) class EngineProcess(ManagedProcess):
try: """Base class for LLM engine processes (vLLM, TRT-LLM, etc.)"""
if response.headers.get("content-type", "").startswith(
"application/json"
):
response_data = response.json()
response_str = json.dumps(response_data, indent=2)
if len(response_str) > 1000:
response_str = response_str[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_str)
else:
response_text = response.text
if len(response_text) > 1000:
response_text = response_text[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_text)
except Exception as e:
logger.debug("Could not parse response body: %s", e)
return response
except requests.exceptions.Timeout:
logger.error("Request timed out after %.2f seconds", timeout)
raise
except requests.exceptions.ConnectionError as e:
logger.error("Connection error: %s", e)
raise
except requests.exceptions.RequestException as e:
logger.error("Request failed: %s", e)
raise
def check_response( def check_response(
self, self,
payload: Any, payload: BasePayload,
response: requests.Response, response: requests.Response,
response_handler: Callable[[Any], str],
) -> None: ) -> None:
""" """
Check if the response is valid and contains expected content. Check if the response is valid and contains expected content.
...@@ -151,30 +86,93 @@ class EngineProcess(ManagedProcess): ...@@ -151,30 +86,93 @@ class EngineProcess(ManagedProcess):
raise EngineResponseError(error_msg) raise EngineResponseError(error_msg)
# Extract content using the handler
try: try:
content = response_handler(response) content = payload.process_response(response)
logger.info( logger.info(
"Extracted content: \n%s", "Extracted content: \n%s",
content[:200] + "..." if len(content) > 200 else content, content[:200] + "..."
if isinstance(content, str) and len(content) > 200
else content,
) )
except AssertionError as e:
raise EngineResponseError(str(e))
except Exception as e: except Exception as e:
raise EngineResponseError(f"Failed to extract content from response: {e}") raise EngineResponseError(f"Failed to handle response: {e}")
# Optionally validate expected log patterns after response handling
if payload.expected_log:
self.validate_expected_logs(payload.expected_log)
def validate_expected_logs(self, patterns: Any) -> None:
"""Validate that all regex patterns are present in the current logs.
Reads the full log via ManagedProcess.read_logs and searches for each
provided regex pattern. Raises EngineLogError if any are missing.
"""
import re # local import to keep module load minimal
content = self.read_logs() or ""
if not content: if not content:
raise EngineResponseError("Response contained empty content") raise EngineLogError(
f"Log file not available or empty at path: {self.log_path}"
)
if hasattr(payload, "expected_response") and payload.expected_response: compiled = [re.compile(p) for p in patterns]
missing_expected = [] missing = []
for expected in payload.expected_response: for pattern, rx in zip(patterns, compiled):
if expected not in content: if not rx.search(content):
missing_expected.append(expected) missing.append(pattern)
if missing_expected: if missing:
raise EngineResponseError( sample = content[-1000:] if len(content) > 1000 else content
f"Expected content not found in response. Missing: {missing_expected}" raise EngineLogError(
) f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
else: )
logger.info( logger.info(f"SUCCESS: All expected log patterns: {patterns} found")
f"SUCCESS: All expected content ({payload.expected_response}) found in response"
) @classmethod
def from_script(
cls,
config: EngineConfig,
request: Any,
extra_env: Optional[Dict[str, str]] = None,
) -> "EngineProcess":
"""Factory to create an EngineProcess configured to run a launch script."""
assert isinstance(config, EngineConfig), "Must use an instance of EngineConfig"
directory = config.directory
script_path = os.path.join(directory, "launch", config.script_name)
if not os.path.exists(script_path):
raise FileNotFoundError(f"Script not found: {script_path}")
command: List[str] = ["bash", script_path]
if config.script_args:
command.extend(config.script_args)
env = os.environ.copy()
if getattr(config, "env", None):
env.update(config.env)
if extra_env:
env.update(extra_env)
return cls(
command=command,
env=env,
timeout=config.timeout,
display_output=True,
working_dir=directory,
health_check_ports=[],
health_check_urls=[
(f"http://localhost:{config.models_port}/v1/models", check_models_api),
(
f"http://localhost:{config.models_port}/health",
check_health_generate,
),
],
delayed_start=config.delayed_start,
terminate_existing=False,
stragglers=config.stragglers,
log_dir=request.node.name,
)
...@@ -73,6 +73,7 @@ class ManagedProcess: ...@@ -73,6 +73,7 @@ class ManagedProcess:
env: Optional[dict] = None env: Optional[dict] = None
health_check_ports: List[int] = field(default_factory=list) health_check_ports: List[int] = field(default_factory=list)
health_check_urls: List[Any] = field(default_factory=list) health_check_urls: List[Any] = field(default_factory=list)
health_check_funcs: List[Any] = field(default_factory=list)
delayed_start: int = 0 delayed_start: int = 0
timeout: int = 300 timeout: int = 300
working_dir: Optional[str] = None working_dir: Optional[str] = None
...@@ -93,6 +94,24 @@ class ManagedProcess: ...@@ -93,6 +94,24 @@ class ManagedProcess:
_tee_proc = None _tee_proc = None
_sed_proc = None _sed_proc = None
@property
def log_path(self):
"""Return the absolute path to the process log file if available."""
return self._log_path
def read_logs(self) -> str:
"""Read and return the entire contents of the process log file.
Returns an empty string if the log file is not yet available.
"""
try:
if self._log_path and os.path.exists(self._log_path):
with open(self._log_path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
except Exception as e:
self._logger.warning("Could not read log file %s: %s", self._log_path, e)
return ""
def __enter__(self): def __enter__(self):
try: try:
self._logger = logging.getLogger(self.__class__.__name__) self._logger = logging.getLogger(self.__class__.__name__)
...@@ -109,6 +128,7 @@ class ManagedProcess: ...@@ -109,6 +128,7 @@ class ManagedProcess:
time.sleep(self.delayed_start) time.sleep(self.delayed_start)
elapsed = self._check_ports(self.timeout) elapsed = self._check_ports(self.timeout)
self._check_urls(self.timeout - elapsed) self._check_urls(self.timeout - elapsed)
self._check_funcs(self.timeout - elapsed)
return self return self
...@@ -121,44 +141,73 @@ class ManagedProcess: ...@@ -121,44 +141,73 @@ class ManagedProcess:
) )
raise raise
def __exit__(self, exc_type, exc_val, exc_tb): def _cleanup_stragglers(self):
self._terminate_process_group() """Clean up straggler processes - called during exit and signal handling"""
try:
if self.stragglers or self.straggler_commands:
self._logger.info(
"Checking for straggler processes: stragglers=%s, straggler_commands=%s",
self.stragglers,
self.straggler_commands,
)
process_list = [self.proc, self._tee_proc, self._sed_proc] for ps_process in psutil.process_iter(["name", "cmdline"]):
for process in process_list:
if process:
try: try:
if process.stdout: process_name = ps_process.name()
process.stdout.close() if process_name in self.stragglers:
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir:
self._remove_directory(self.data_dir)
for ps_process in psutil.process_iter(["name", "cmdline"]):
try:
if ps_process.name() in self.stragglers:
self._logger.info(
"Terminating Straggler %s %s", ps_process.name(), ps_process.pid
)
terminate_process_tree(ps_process.pid, self._logger)
for cmdline in self.straggler_commands:
if cmdline in " ".join(ps_process.cmdline()):
self._logger.info( self._logger.info(
"Terminating Straggler Cmdline %s %s %s", "Terminating Straggler %s %s", process_name, ps_process.pid
ps_process.name(),
ps_process.pid,
cmdline,
) )
terminate_process_tree(ps_process.pid, self._logger) terminate_process_tree(ps_process.pid, self._logger)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
# Process may have terminated or become inaccessible during iteration # Check command line arguments
pass cmdline = ps_process.cmdline()
cmdline_str = " ".join(cmdline) if cmdline else ""
for straggler_cmd in self.straggler_commands:
if straggler_cmd in cmdline_str:
self._logger.info(
"Terminating Straggler Cmdline %s %s %s",
process_name,
ps_process.pid,
straggler_cmd,
)
terminate_process_tree(ps_process.pid, self._logger)
break # Avoid terminating the same process multiple times
except (
psutil.NoSuchProcess,
psutil.AccessDenied,
psutil.ZombieProcess,
):
# Process may have terminated or become inaccessible during iteration
pass
except Exception as e:
# Catch any other unexpected errors to ensure cleanup continues
self._logger.warning("Error checking process: %s", e)
except Exception as e:
# Ensure that any error in straggler cleanup doesn't prevent other cleanup
self._logger.error("Error during straggler cleanup: %s", e)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self._terminate_process_group()
process_list = [self.proc, self._tee_proc, self._sed_proc]
for process in process_list:
if process:
try:
if process.stdout:
process.stdout.close()
if process.stdin:
process.stdin.close()
terminate_process_tree(process.pid, self._logger)
process.wait()
except Exception as e:
self._logger.warning("Error terminating process: %s", e)
if self.data_dir:
self._remove_directory(self.data_dir)
finally:
# Always run straggler cleanup, even if interrupted
self._cleanup_stragglers()
def _start_process(self): def _start_process(self):
assert self._command_name assert self._command_name
...@@ -327,7 +376,7 @@ class ManagedProcess: ...@@ -327,7 +376,7 @@ class ManagedProcess:
elapsed += self._check_url(url, timeout - elapsed) elapsed += self._check_url(url, timeout - elapsed)
return elapsed return elapsed
def _check_url(self, url, timeout=30, sleep=1, log_interval=10): def _check_url(self, url, timeout=30, sleep=1, log_interval=20):
if isinstance(url, tuple): if isinstance(url, tuple):
response_check = url[1] response_check = url[1]
url = url[0] url = url[0]
...@@ -403,6 +452,70 @@ class ManagedProcess: ...@@ -403,6 +452,70 @@ class ManagedProcess:
) )
raise RuntimeError("FAILED: Check URL: %s" % url) raise RuntimeError("FAILED: Check URL: %s" % url)
def _check_funcs(self, timeout):
elapsed = 0.0
for func in self.health_check_funcs:
elapsed += self._check_func(func, timeout - elapsed)
return elapsed
def _check_func(self, func, timeout=30, sleep=1, log_interval=20):
start_time = time.time()
func_name = getattr(func, "__name__", str(func))
self._logger.info("Running custom health check '%s'", func_name)
elapsed = 0.0
attempt = 0
last_log_time = 0.0
while elapsed < timeout:
self._check_process_alive("while waiting for health check")
attempt += 1
check_failed = False
failure_reason = None
try:
# Prefer functions that accept remaining timeout; fall back to no-arg call
try:
result = func(timeout - elapsed)
except TypeError:
result = func()
if bool(result):
self._logger.info(
"SUCCESS: Custom health check '%s' passed (attempt=%d, elapsed=%.1fs)",
func_name,
attempt,
elapsed,
)
return time.time() - start_time
else:
check_failed = True
failure_reason = "returned False"
except Exception as e:
check_failed = True
failure_reason = f"exception: {e}"
if check_failed and elapsed - last_log_time >= log_interval:
self._logger.info(
"Still waiting on custom health check '%s' (%s) (attempt=%d, elapsed=%.1fs)",
func_name,
failure_reason,
attempt,
elapsed,
)
last_log_time = elapsed
time.sleep(sleep)
elapsed = time.time() - start_time
self._logger.error(
"FAILED: Custom health check '%s' (attempts=%d, elapsed=%.1fs)",
func_name,
attempt,
elapsed,
)
raise RuntimeError("FAILED: Custom health check")
def _terminate_existing(self): def _terminate_existing(self):
if self.terminate_existing: if self.terminate_existing:
for proc in psutil.process_iter(["name", "cmdline"]): for proc in psutil.process_iter(["name", "cmdline"]):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request
from tests.utils.payloads import ChatPayload, CompletionPayload, MetricsPayload
# Common default text prompt used across tests
TEXT_PROMPT = "Tell me a short joke about AI."
def chat_payload_default(
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> ChatPayload:
return ChatPayload(
body={
"messages": [
{
"role": "user",
"content": TEXT_PROMPT,
}
],
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or ["AI"],
)
def completion_payload_default(
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> CompletionPayload:
return CompletionPayload(
body={
"prompt": TEXT_PROMPT,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or ["AI"],
)
def metric_payload_default(
min_num_requests: int,
repeat_count: int = 1,
expected_log: Optional[List[str]] = None,
) -> MetricsPayload:
return MetricsPayload(
body={},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=[],
min_num_requests=min_num_requests,
)
def chat_payload(
content: Union[str, List[Dict[str, Any]]],
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 300,
temperature: Optional[float] = None,
stream: bool = False,
) -> ChatPayload:
body: Dict[str, Any] = {
"messages": [
{
"role": "user",
"content": content,
}
],
"max_tokens": max_tokens,
"stream": stream,
}
if temperature is not None:
body["temperature"] = temperature
return ChatPayload(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
def completion_payload(
prompt: str,
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
) -> CompletionPayload:
return CompletionPayload(
body={
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
# Build small request-based health checks for chat and completions
# these should only be used as a last resort. Generally want to use an actual health check
def make_chat_health_check(port: int, model: str):
def _check_chat_endpoint(remaining_timeout: float = 30.0) -> bool:
payload = chat_payload_default(
repeat_count=1,
expected_response=[],
max_tokens=8,
temperature=0.0,
stream=False,
).with_model(model)
payload.port = port
try:
resp = send_request(
payload.url(),
payload.body,
timeout=min(max(1.0, remaining_timeout), 5.0),
method=payload.method,
log_level=10,
)
# Validate structure only; expected_response is empty
_ = payload.response_handler(resp)
return True
except Exception:
return False
return _check_chat_endpoint
def make_completions_health_check(port: int, model: str):
def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
payload = completion_payload_default(
repeat_count=1,
expected_response=[],
max_tokens=8,
temperature=0.0,
stream=False,
).with_model(model)
payload.port = port
try:
resp = send_request(
payload.url(),
payload.body,
timeout=min(max(1.0, remaining_timeout), 5.0),
method=payload.method,
log_level=10,
)
out = payload.response_handler(resp)
if not out:
raise ValueError("")
return True
except Exception:
return False
return _check_completions_endpoint
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List
logger = logging.getLogger(__name__)
@dataclass
class BasePayload:
"""Generic payload body plus expectations and repeat count."""
body: Dict[str, Any]
expected_response: List[str]
expected_log: List[str]
repeat_count: int = 1
timeout: int = 30
# Connection info
host: str = "localhost"
port: int = 8000
endpoint: str = ""
method: str = "POST"
def url(self) -> str:
ep = self.endpoint.lstrip("/")
return f"http://{self.host}:{self.port}/{ep}"
def with_model(self, model):
p = deepcopy(self)
if "model" not in p.body:
p.body = {**p.body, "model": model}
return p
def response_handler(self, response: Any) -> str:
"""Extract a text representation of the response for logging/validation."""
raise NotImplementedError("Subclasses must implement response_handler()")
def validate(self, response: Any, content: str) -> None:
"""Default validation: ensure expected substrings appear in content."""
if self.expected_response:
missing_expected = []
for expected in self.expected_response:
if not content or expected not in content:
missing_expected.append(expected)
if missing_expected:
raise AssertionError(
f"Expected content not found in response. Missing: {missing_expected}"
)
logger.info(f"SUCCESS: All expected_responses: {self.expected_response} found.")
def process_response(self, response: Any) -> str:
"""Convenience: run response_handler then validate; return content."""
content = self.response_handler(response)
self.validate(response, content)
return content
@dataclass
class ChatPayload(BasePayload):
"""Payload for chat completions endpoint."""
endpoint: str = "/v1/chat/completions"
@staticmethod
def extract_content(response):
"""
Process chat completions API responses.
"""
response.raise_for_status()
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "message" in result["choices"][0], "Missing 'message' in first choice"
# Check for content in all possible fields where parsers might put output:
# 1. content - standard message content
# 2. reasoning_content - for models with reasoning parsers
# 3. refusal - when the model refuses to answer
# 4. tool_calls - for function/tool calling responses
message = result["choices"][0]["message"]
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", "")
refusal = message.get("refusal", "")
tool_calls = message.get("tool_calls", [])
tool_content = ""
if tool_calls:
tool_content = ", ".join(
call.get("function", {}).get("arguments", "")
for call in tool_calls
if call.get("function", {}).get("arguments")
)
for field_content in [content, reasoning_content, refusal, tool_content]:
if field_content:
return field_content
raise ValueError(
"All possible content fields are empty in message. "
f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
f"refusal={repr(refusal)}, tool_calls={tool_calls}"
)
def response_handler(self, response: Any) -> str:
return ChatPayload.extract_content(response)
@dataclass
class CompletionPayload(BasePayload):
"""Payload for completions endpoint."""
endpoint: str = "/v1/completions"
@staticmethod
def extract_text(response):
"""
Process completions API responses.
"""
response.raise_for_status()
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "text" in result["choices"][0], "Missing 'text' in first choice"
return result["choices"][0]["text"]
def response_handler(self, response: Any) -> str:
return CompletionPayload.extract_text(response)
@dataclass
class MetricsPayload(BasePayload):
endpoint: str = "/metrics"
method: str = "GET"
port: int = 8081
min_num_requests: int = 1
def with_model(self, model):
# Metrics does not use model in request body
return self
def response_handler(self, response: Any) -> str:
response.raise_for_status()
return response.text
def validate(self, response: Any, content: str) -> None:
pattern = r'dynamo_component_requests_total\{[^}]*model="[^"]*"[^}]*\}\s+(\d+)'
matches = re.findall(pattern, content)
if not matches:
raise AssertionError(
"Metric 'dynamo_component_requests_total' with model label not found in metrics output"
)
for match in matches:
request_count = int(match)
if request_count >= self.min_num_requests:
logger.info(
f"SUCCESS: Found dynamo_component_requests_total with count: {request_count}"
)
return
raise AssertionError(
f"dynamo_component_requests_total exists but has count {request_count} which is less than required {self.min_num_requests}"
)
def check_models_api(response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
# Additional health check helpers
def check_health_generate(response):
"""Validate /health reports a 'generate' endpoint.
Returns True if either of the following is found:
- "endpoints" contains a string mentioning 'generate'
- "instances" contains an object with endpoint == 'generate'
"""
try:
if response.status_code != 200:
return False
data = response.json()
# Check endpoints list for any entry containing 'generate'
endpoints = data.get("endpoints", []) or []
for ep in endpoints:
if isinstance(ep, str) and "generate" in ep:
return True
# Check instances for an entry with endpoint == 'generate'
instances = data.get("instances", []) or []
for inst in instances:
if isinstance(inst, dict) and inst.get("endpoint") == "generate":
return True
return False
except Exception:
return False
# backwards compatiability
def completions_response_handler(response):
return CompletionPayload.extract_text(response)
def chat_completions_response_handler(response):
return ChatPayload.extract_content(response)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment