test_t2i_model.py 2.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
from pathlib import Path

import pytest
import torch

from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"


models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"]

# Modelscope can't find riverclouds/qwen_image_random
# TODO: When NPU support is ready, remove this branch.
if current_omni_platform.is_npu():
    models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"]
elif current_omni_platform.is_rocm():
    # TODO: When ROCm support is ready, remove this branch.
    # vLLM V0.11.0 has issues running riverclouds/qwen_image_random
    # on ROCm
    models = ["Tongyi-MAI/Z-Image-Turbo"]


@pytest.mark.parametrize("model_name", models)
def test_diffusion_model(model_name: str):
    m = None
    try:
        m = Omni(model=model_name)
        # high resolution may cause OOM on L4
        height = 256
        width = 256
        outputs = m.generate(
            "a photo of a cat sitting on a laptop keyboard",
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                num_inference_steps=2,
                guidance_scale=0.0,
                generator=torch.Generator("cuda").manual_seed(42),
                num_outputs_per_prompt=2,
            ),
        )
        # Extract images from request_output[0]['images']
        first_output = outputs[0]
        assert first_output.final_output_type == "image"
        if not hasattr(first_output, "request_output") or not first_output.request_output:
            raise ValueError("No request_output found in OmniRequestOutput")

        req_out = first_output.request_output[0]
        if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
            raise ValueError("Invalid request_output structure or missing 'images' key")

        images = req_out.images

        assert len(images) == 2
        # check image size
        assert images[0].width == width
        assert images[0].height == height
        images[0].save("image_output.png")
    except Exception as e:
        print(f"Test failed with error: {e}")
        raise
    finally:
        if m is not None and hasattr(m, "close"):
            m.close()