Unverified Commit cdddaeda authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

test: Add dynamo serve TRTLLM example to pytest (#1417)

parent 4de7f44c
...@@ -32,6 +32,22 @@ logging.basicConfig( ...@@ -32,6 +32,22 @@ logging.basicConfig(
) )
def pytest_collection_modifyitems(config, items):
"""
This function is called to modify the list of tests to run.
It is used to skip tests that are not supported on all environments.
"""
# Tests marked with tensorrtllm requires specific environment with tensorrtllm
# installed. Hence, we skip them if the user did not explicitly ask for them.
if config.getoption("-m") and "tensorrtllm" in config.getoption("-m"):
return
skip_tensorrtllm = pytest.mark.skip(reason="need -m tensorrtllm to run")
for item in items:
if "tensorrtllm" in item.keywords:
item.add_marker(skip_tensorrtllm)
class EtcdServer(ManagedProcess): class EtcdServer(ManagedProcess):
def __init__(self, request, port=2379, timeout=300): def __init__(self, request, port=2379, timeout=300):
port_string = str(port) port_string = str(port)
......
...@@ -24,6 +24,7 @@ import requests ...@@ -24,6 +24,7 @@ import requests
from tests.utils.deployment_graph import ( from tests.utils.deployment_graph import (
DeploymentGraph, DeploymentGraph,
Payload, Payload,
chat_completions_response_handler,
completions_response_handler, completions_response_handler,
) )
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -31,7 +32,7 @@ from tests.utils.managed_process import ManagedProcess ...@@ -31,7 +32,7 @@ from tests.utils.managed_process import ManagedProcess
text_prompt = "Tell me a short joke about AI." text_prompt = "Tell me a short joke about AI."
multimodal_payload = Payload( multimodal_payload = Payload(
payload={ payload_chat={
"model": "llava-hf/llava-1.5-7b-hf", "model": "llava-hf/llava-1.5-7b-hf",
"messages": [ "messages": [
{ {
...@@ -50,12 +51,13 @@ multimodal_payload = Payload( ...@@ -50,12 +51,13 @@ multimodal_payload = Payload(
"max_tokens": 300, # Reduced from 500 "max_tokens": 300, # Reduced from 500
"stream": False, "stream": False,
}, },
repeat_count=1,
expected_log=[], expected_log=[],
expected_response=["bus"], expected_response=["bus"],
) )
text_payload = Payload( text_payload = Payload(
payload={ payload_chat={
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [ "messages": [
{ {
...@@ -67,6 +69,14 @@ text_payload = Payload( ...@@ -67,6 +69,14 @@ text_payload = Payload(
"temperature": 0.1, "temperature": 0.1,
"seed": 0, "seed": 0,
}, },
payload_completions={
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"prompt": text_prompt,
"max_tokens": 150,
"temperature": 0.1,
"seed": 0,
},
repeat_count=10,
expected_log=[], expected_log=[],
expected_response=["AI"], expected_response=["AI"],
) )
...@@ -77,8 +87,11 @@ deployment_graphs = { ...@@ -77,8 +87,11 @@ deployment_graphs = {
module="graphs.agg:Frontend", module="graphs.agg:Frontend",
config="configs/agg.yaml", config="configs/agg.yaml",
directory="/workspace/examples/llm", directory="/workspace/examples/llm",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.vllm], marks=[pytest.mark.gpu_1, pytest.mark.vllm],
), ),
text_payload, text_payload,
...@@ -88,8 +101,11 @@ deployment_graphs = { ...@@ -88,8 +101,11 @@ deployment_graphs = {
module="graphs.agg:Frontend", module="graphs.agg:Frontend",
config="configs/agg.yaml", config="configs/agg.yaml",
directory="/workspace/examples/sglang", directory="/workspace/examples/sglang",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.sglang], marks=[pytest.mark.gpu_1, pytest.mark.sglang],
), ),
text_payload, text_payload,
...@@ -99,8 +115,11 @@ deployment_graphs = { ...@@ -99,8 +115,11 @@ deployment_graphs = {
module="graphs.disagg:Frontend", module="graphs.disagg:Frontend",
config="configs/disagg.yaml", config="configs/disagg.yaml",
directory="/workspace/examples/llm", directory="/workspace/examples/llm",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2, pytest.mark.vllm],
), ),
text_payload, text_payload,
...@@ -110,8 +129,11 @@ deployment_graphs = { ...@@ -110,8 +129,11 @@ deployment_graphs = {
module="graphs.agg_router:Frontend", module="graphs.agg_router:Frontend",
config="configs/agg_router.yaml", config="configs/agg_router.yaml",
directory="/workspace/examples/llm", directory="/workspace/examples/llm",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.vllm], marks=[pytest.mark.gpu_1, pytest.mark.vllm],
), ),
text_payload, text_payload,
...@@ -121,8 +143,11 @@ deployment_graphs = { ...@@ -121,8 +143,11 @@ deployment_graphs = {
module="graphs.disagg_router:Frontend", module="graphs.disagg_router:Frontend",
config="configs/disagg_router.yaml", config="configs/disagg_router.yaml",
directory="/workspace/examples/llm", directory="/workspace/examples/llm",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2, pytest.mark.vllm],
), ),
text_payload, text_payload,
...@@ -132,8 +157,11 @@ deployment_graphs = { ...@@ -132,8 +157,11 @@ deployment_graphs = {
module="graphs.agg:Frontend", module="graphs.agg:Frontend",
config="configs/agg.yaml", config="configs/agg.yaml",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2, pytest.mark.vllm],
), ),
multimodal_payload, multimodal_payload,
...@@ -143,12 +171,79 @@ deployment_graphs = { ...@@ -143,12 +171,79 @@ deployment_graphs = {
module="graphs.agg:Frontend", module="graphs.agg:Frontend",
config="configs/agg.yaml", config="configs/agg.yaml",
directory="/workspace/examples/vllm_v1", directory="/workspace/examples/vllm_v1",
endpoint="v1/chat/completions", endpoints=["v1/chat/completions", "v1/completions"],
response_handler=completions_response_handler, response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.vllm], marks=[pytest.mark.gpu_1, pytest.mark.vllm],
), ),
text_payload, text_payload,
), ),
"trtllm_agg": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="configs/agg.yaml",
directory="/workspace/examples/tensorrt_llm",
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.tensorrtllm],
),
text_payload,
),
"trtllm_agg_router": (
DeploymentGraph(
module="graphs.agg_router:Frontend",
config="configs/agg_router.yaml",
directory="/workspace/examples/tensorrt_llm",
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_1, pytest.mark.tensorrtllm],
# FIXME: This is a hack to allow deployments to start before sending any requests.
# When using KV-router, if all the endpoints are not registered, the service
# enters a non-recoverable state.
delayed_start=60,
),
text_payload,
),
"trtllm_disagg": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="configs/disagg.yaml",
directory="/workspace/examples/tensorrt_llm",
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2, pytest.mark.tensorrtllm],
),
text_payload,
),
"trtllm_disagg_router": (
DeploymentGraph(
module="graphs.disagg_router:Frontend",
config="configs/disagg_router.yaml",
directory="/workspace/examples/tensorrt_llm",
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
marks=[pytest.mark.gpu_2, pytest.mark.tensorrtllm],
# FIXME: This is a hack to allow deployments to start before sending any requests.
# When using KV-router, if all the endpoints are not registered, the service
# enters a non-recoverable state.
delayed_start=120,
),
text_payload,
),
} }
...@@ -175,6 +270,7 @@ class DynamoServeProcess(ManagedProcess): ...@@ -175,6 +270,7 @@ class DynamoServeProcess(ManagedProcess):
working_dir=graph.directory, working_dir=graph.directory,
health_check_ports=[port], health_check_ports=[port],
health_check_urls=health_check_urls, health_check_urls=health_check_urls,
delayed_start=graph.delayed_start,
stragglers=["http"], stragglers=["http"],
log_dir=request.node.name, log_dir=request.node.name,
) )
...@@ -196,6 +292,16 @@ class DynamoServeProcess(ManagedProcess): ...@@ -196,6 +292,16 @@ class DynamoServeProcess(ManagedProcess):
pytest.param("disagg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]), pytest.param("disagg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("disagg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_2]), pytest.param("disagg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("multimodal_agg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]), pytest.param("multimodal_agg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("trtllm_agg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]),
pytest.param(
"trtllm_agg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]
),
pytest.param(
"trtllm_disagg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
pytest.param(
"trtllm_disagg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
# pytest.param("sglang", marks=[pytest.mark.sglang, pytest.mark.gpu_2]), # pytest.param("sglang", marks=[pytest.mark.sglang, pytest.mark.gpu_2]),
] ]
) )
...@@ -220,65 +326,89 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services): ...@@ -220,65 +326,89 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services):
deployment_graph, payload = deployment_graph_test deployment_graph, payload = deployment_graph_test
with DynamoServeProcess(deployment_graph, request) as server_process: def check_response(response, response_handler):
url = f"http://localhost:{server_process.port}/{deployment_graph.endpoint}" assert response.status_code == 200, "Server is not healthy"
start_time = time.time() content = response_handler(response)
retry_delay = 5
elapsed = 0.0
while time.time() - start_time < deployment_graph.timeout:
elapsed = time.time() - start_time
try:
response = requests.post(
url,
json=payload.payload,
timeout=deployment_graph.timeout - elapsed,
)
except (requests.RequestException, requests.Timeout) as e:
logger.warning("Retrying due to Request failed: %s", e)
time.sleep(retry_delay)
continue
logger.info("Response%r", response)
if response.status_code == 500:
error = response.json().get("error", "")
if "no instances" in error:
logger.warning("Retrying due to no instances available")
time.sleep(retry_delay)
continue
if response.status_code == 404:
error = response.json().get("error", "")
if "Model not found" in error:
logger.warning("Retrying due to model not found")
time.sleep(retry_delay)
continue
# Process the response
if response.status_code != 200:
logger.error(
"Service returned status code %s: %s",
response.status_code,
response.text,
)
pytest.fail(
"Service returned status code %s: %s"
% (response.status_code, response.text)
)
else:
break
else:
logger.error(
"Service did not return a successful response within %s s",
deployment_graph.timeout,
)
pytest.fail(
"Service did not return a successful response within %s s"
% deployment_graph.timeout
)
content = deployment_graph.response_handler(response)
logger.info("Received Content: %s", content) logger.info("Received Content: %s", content)
# Check for expected responses # Check for expected responses
assert content, "Empty response content" assert content, "Empty response content"
for expected in payload.expected_response: for expected in payload.expected_response:
assert expected in content, "Expected '%s' not found in response" % expected assert expected in content, "Expected '%s' not found in response" % expected
with DynamoServeProcess(deployment_graph, request) as server_process:
first_success_pending = True
for endpoint, response_handler in zip(
deployment_graph.endpoints, deployment_graph.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
retry_delay = 5
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
# We can skip this
while (
time.time() - start_time < deployment_graph.timeout
and first_success_pending
):
elapsed = time.time() - start_time
try:
response = requests.post(
url,
json=request_body,
timeout=deployment_graph.timeout - elapsed,
)
except (requests.RequestException, requests.Timeout) as e:
logger.warning("Retrying due to Request failed: %s", e)
time.sleep(retry_delay)
continue
logger.info("Response%r", response)
if response.status_code == 500:
error = response.json().get("error", "")
if "no instances" in error:
logger.warning("Retrying due to no instances available")
time.sleep(retry_delay)
continue
if response.status_code == 404:
error = response.json().get("error", "")
if "Model not found" in error:
logger.warning("Retrying due to model not found")
time.sleep(retry_delay)
continue
# Process the response
if response.status_code != 200:
logger.error(
"Service returned status code %s: %s",
response.status_code,
response.text,
)
pytest.fail(
"Service returned status code %s: %s"
% (response.status_code, response.text)
)
else:
check_response(response, response_handler)
first_success_pending = False
break
else:
if first_success_pending:
logger.error(
"Service did not return a successful response within %s s",
deployment_graph.timeout,
)
pytest.fail(
"Service did not return a successful response within %s s"
% deployment_graph.timeout
)
for _ in range(payload.repeat_count):
response = requests.post(
url,
json=request_body,
timeout=deployment_graph.timeout - elapsed,
)
check_response(response, response_handler)
...@@ -26,9 +26,10 @@ class DeploymentGraph: ...@@ -26,9 +26,10 @@ class DeploymentGraph:
module: str module: str
config: str config: str
directory: str directory: str
endpoint: str endpoints: List[str]
response_handler: Callable[[Any], str] response_handlers: List[Callable[[Any], str]]
timeout: int = 900 timeout: int = 900
delayed_start: int = 0
marks: Optional[List[Any]] = field(default_factory=list) marks: Optional[List[Any]] = field(default_factory=list)
...@@ -38,12 +39,14 @@ class Payload: ...@@ -38,12 +39,14 @@ class Payload:
Represents a test payload with expected response and log patterns. Represents a test payload with expected response and log patterns.
""" """
payload: Dict[str, Any] payload_chat: Dict[str, Any]
expected_response: List[str] expected_response: List[str]
expected_log: List[str] expected_log: List[str]
repeat_count: int = 1
payload_completions: Optional[Dict[str, Any]] = None
def completions_response_handler(response): def chat_completions_response_handler(response):
""" """
Process chat completions API responses. Process chat completions API responses.
""" """
...@@ -55,3 +58,16 @@ def completions_response_handler(response): ...@@ -55,3 +58,16 @@ def completions_response_handler(response):
assert "message" in result["choices"][0], "Missing 'message' in first choice" assert "message" in result["choices"][0], "Missing 'message' in first choice"
assert "content" in result["choices"][0]["message"], "Missing 'content' in message" assert "content" in result["choices"][0]["message"], "Missing 'content' in message"
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
def completions_response_handler(response):
"""
Process completions API responses.
"""
if response.status_code != 200:
return ""
result = response.json()
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "text" in result["choices"][0], "Missing 'text' in first choice"
return result["choices"][0]["text"]
...@@ -32,6 +32,7 @@ class ManagedProcess: ...@@ -32,6 +32,7 @@ class ManagedProcess:
env: Optional[dict] = None env: Optional[dict] = None
health_check_ports: List[int] = field(default_factory=list) health_check_ports: List[int] = field(default_factory=list)
health_check_urls: List[Any] = field(default_factory=list) health_check_urls: List[Any] = field(default_factory=list)
delayed_start: int = 0
timeout: int = 300 timeout: int = 300
working_dir: Optional[str] = None working_dir: Optional[str] = None
display_output: bool = False display_output: bool = False
...@@ -59,6 +60,7 @@ class ManagedProcess: ...@@ -59,6 +60,7 @@ class ManagedProcess:
self._terminate_existing() self._terminate_existing()
self._start_process() self._start_process()
time.sleep(self.delayed_start)
elapsed = self._check_ports(self.timeout) elapsed = self._check_ports(self.timeout)
self._check_urls(self.timeout - elapsed) self._check_urls(self.timeout - elapsed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment