Unverified Commit 9bf9709a authored by Krishnan Prashanth's avatar Krishnan Prashanth Committed by GitHub
Browse files

refactor(tests): consolidate diffusion payload classes into shared payloads.py (#8435)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent d932d3b4
...@@ -5,7 +5,6 @@ import dataclasses ...@@ -5,7 +5,6 @@ import dataclasses
import logging import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
import pytest import pytest
...@@ -26,40 +25,11 @@ from tests.utils.payload_builder import ( ...@@ -26,40 +25,11 @@ from tests.utils.payload_builder import (
metric_payload_default, metric_payload_default,
multimodal_payload_default, multimodal_payload_default,
) )
from tests.utils.payloads import BasePayload from tests.utils.payloads import VideoGenerationPayload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class VideoGenerationPayload(BasePayload):
"""Payload for /v1/videos endpoint (TRT-LLM video diffusion)."""
endpoint: str = "/v1/videos"
timeout: int = 300
def response_handler(self, response: Any) -> str:
response.raise_for_status()
result = response.json()
assert result.get("status") == "completed", (
f"Video generation not completed. Status: {result.get('status')}, "
f"Error: {result.get('error', 'none')}"
)
assert (
"data" in result
), f"Missing 'data' in response. Keys: {list(result.keys())}"
assert len(result["data"]) > 0, "Empty data in video response"
entry = result["data"][0]
if "url" in entry:
assert entry["url"], "Video response url is empty"
return entry["url"]
assert entry.get("b64_json"), "Video response b64_json is empty"
return "b64_video_returned"
def validate(self, response: Any, content: str) -> None:
assert content, "Video response content is empty"
@dataclass @dataclass
class TRTLLMConfig(EngineConfig): class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios""" """Configuration for trtllm test scenarios"""
...@@ -437,6 +407,7 @@ trtllm_configs = { ...@@ -437,6 +407,7 @@ trtllm_configs = {
"seed": 42, "seed": 42,
}, },
}, },
timeout=300,
repeat_count=1, repeat_count=1,
expected_response=[], expected_response=[],
expected_log=[], expected_log=[],
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import base64
import dataclasses import dataclasses
import logging import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from io import BytesIO
from typing import Any
import pytest import pytest
...@@ -22,7 +19,13 @@ from tests.serve.common import ( ...@@ -22,7 +19,13 @@ from tests.serve.common import (
run_serve_deployment, run_serve_deployment,
) )
from tests.utils.engine_process import EngineConfig from tests.utils.engine_process import EngineConfig
from tests.utils.payloads import BasePayload, ChatPayload from tests.utils.payloads import (
AudioSpeechPayload,
ChatPayload,
I2VPayload,
ImageGenerationPayload,
VideoGenerationPayload,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -31,110 +34,6 @@ vllm_dir = os.environ.get("VLLM_DIR") or os.path.join( ...@@ -31,110 +34,6 @@ vllm_dir = os.environ.get("VLLM_DIR") or os.path.join(
) )
@dataclass
class ImageGenerationPayload(BasePayload):
"""Payload for /v1/images/generations endpoint."""
endpoint: str = "/v1/images/generations"
timeout: int = 300
def response_handler(self, response: Any) -> str:
response.raise_for_status()
result = response.json()
assert (
"data" in result
), f"Missing 'data' in response. Keys: {list(result.keys())}"
assert len(result["data"]) > 0, "Empty data in image response"
entry = result["data"][0]
if "url" in entry:
assert entry["url"], "Image response url is empty"
return entry["url"]
assert entry.get("b64_json"), "Image response b64_json is empty"
return "b64_image_returned"
@dataclass
class VideoGenerationPayload(BasePayload):
"""Payload for /v1/videos endpoint."""
endpoint: str = "/v1/videos"
timeout: int = 600
def response_handler(self, response: Any) -> str:
response.raise_for_status()
result = response.json()
assert result.get("status") == "completed", (
f"Video generation not completed. Status: {result.get('status')}, "
f"Error: {result.get('error', 'none')}"
)
assert (
"data" in result
), f"Missing 'data' in response. Keys: {list(result.keys())}"
assert len(result["data"]) > 0, "Empty data in video response"
entry = result["data"][0]
if "url" in entry:
assert entry["url"], "Video response url is empty"
return entry["url"]
assert entry.get("b64_json"), "Video response b64_json is empty"
return "b64_video_returned"
def validate(self, response: Any, content: str) -> None:
assert content, "Video response content is empty"
if self.expected_response and not any(
expected.lower() in content.lower() for expected in self.expected_response
):
raise AssertionError(
f"Expected at least one of {self.expected_response} in {content!r}"
)
@dataclass
class I2VPayload(VideoGenerationPayload):
"""Payload for image-to-video via /v1/videos with input_reference."""
def __post_init__(self):
from PIL import Image
image_buffer = BytesIO()
Image.new("RGB", (64, 64), color="red").save(image_buffer, format="PNG")
image_b64 = base64.b64encode(image_buffer.getvalue()).decode("ascii")
self.body["input_reference"] = f"data:image/png;base64,{image_b64}"
@dataclass
class AudioSpeechPayload(BasePayload):
"""Payload for /v1/audio/speech endpoint."""
endpoint: str = "/v1/audio/speech"
timeout: int = 300
def response_handler(self, response: Any) -> str:
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "audio" in content_type:
# Binary audio response
audio_bytes = response.content
assert len(audio_bytes) > 100, (
f"Audio response too small ({len(audio_bytes)} bytes), "
f"likely not valid audio"
)
return f"binary_audio_{len(audio_bytes)}_bytes"
# JSON response (error or url format)
result = response.json()
assert (
result.get("status") != "failed"
), f"Audio generation failed: {result.get('error', 'unknown')}"
assert (
"data" in result
), f"Missing 'data' in response. Keys: {list(result.keys())}"
assert len(result["data"]) > 0, "Empty data in audio response"
entry = result["data"][0]
if "url" in entry and entry["url"]:
return entry["url"]
assert entry.get("b64_json"), "Audio response b64_json is empty"
return "b64_audio_returned"
@dataclass @dataclass
class VLLMOmniConfig(EngineConfig): class VLLMOmniConfig(EngineConfig):
"""Configuration for vLLM-Omni test scenarios.""" """Configuration for vLLM-Omni test scenarios."""
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import logging import logging
import math import math
import re import re
import time import time
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, cast from typing import Any, Callable, Dict, List, Optional, cast
import requests import requests
...@@ -1365,3 +1367,57 @@ class VideoGenerationPayload(BasePayload): ...@@ -1365,3 +1367,57 @@ class VideoGenerationPayload(BasePayload):
return entry["url"] return entry["url"]
assert entry.get("b64_json"), "Video response b64_json is empty" assert entry.get("b64_json"), "Video response b64_json is empty"
return "b64_video_returned" return "b64_video_returned"
def validate(self, response: Any, content: str) -> None:
assert content, "Video response content is empty"
if self.expected_response and not any(
expected.lower() in content.lower() for expected in self.expected_response
):
raise AssertionError(
f"Expected at least one of {self.expected_response} in {content!r}"
)
@dataclass
class I2VPayload(VideoGenerationPayload):
"""Payload for image-to-video via /v1/videos with input_reference."""
def __post_init__(self):
from PIL import Image
image_buffer = BytesIO()
Image.new("RGB", (64, 64), color="red").save(image_buffer, format="PNG")
image_b64 = base64.b64encode(image_buffer.getvalue()).decode("ascii")
self.body["input_reference"] = f"data:image/png;base64,{image_b64}"
@dataclass
class AudioSpeechPayload(BasePayload):
"""Payload for /v1/audio/speech endpoint."""
endpoint: str = "/v1/audio/speech"
timeout: int = 300
def response_handler(self, response: Any) -> str:
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "audio" in content_type:
audio_bytes = response.content
assert len(audio_bytes) > 100, (
f"Audio response too small ({len(audio_bytes)} bytes), "
f"likely not valid audio"
)
return f"binary_audio_{len(audio_bytes)}_bytes"
result = response.json()
assert (
result.get("status") != "failed"
), f"Audio generation failed: {result.get('error', 'unknown')}"
assert (
"data" in result
), f"Missing 'data' in response. Keys: {list(result.keys())}"
assert len(result["data"]) > 0, "Empty data in audio response"
entry = result["data"][0]
if "url" in entry and entry["url"]:
return entry["url"]
assert entry.get("b64_json"), "Audio response b64_json is empty"
return "b64_audio_returned"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment