test_bagel_text2img.py 9.87 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
End-to-end test for Bagel text2img generation.

This test validates that the Bagel model generates images that match
expected reference pixel values within a ±5 tolerance.

Equivalent to running:
    python3 examples/offline_inference/bagel/end2end.py \
        --prompts "A futuristic city skyline at twilight, cyberpunk style" \
        --modality text2img --step 15
"""

import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

import signal
import socket
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any

import pytest
from PIL import Image

from tests.utils import hardware_test
from vllm_omni.entrypoints.omni import Omni

# Reference pixel data sampled from a known-good output image.
# Each entry holds an (x, y) position and the expected (R, G, B) values there.
# The reference image was generated with seed=52, num_inference_steps=15 and
# prompt='A futuristic city skyline at twilight, cyberpunk style'.
# NOTE(review): the visible test code never sets a seed explicitly — presumably
# it comes from the stage config; confirm before regenerating these values.
REFERENCE_PIXELS = [
    {"position": (100, 100), "rgb": (68, 107, 134)},
    {"position": (400, 50), "rgb": (95, 139, 166)},
    {"position": (700, 100), "rgb": (99, 122, 151)},
    {"position": (150, 400), "rgb": (111, 125, 153)},
    {"position": (512, 512), "rgb": (97, 107, 131)},
    {"position": (700, 400), "rgb": (48, 64, 98)},
    {"position": (100, 700), "rgb": (79, 63, 84)},
    {"position": (400, 700), "rgb": (40, 58, 79)},
    {"position": (700, 700), "rgb": (60, 75, 103)},
    {"position": (256, 256), "rgb": (97, 128, 156)},
]

# Maximum allowed absolute difference per color channel when comparing a
# generated pixel against its reference value.
PIXEL_TOLERANCE = 5

# Default test prompt: the reference prompt text wrapped in
# <|im_start|> ... <|im_end|> markers.
DEFAULT_PROMPT = "<|im_start|>A futuristic city skyline at twilight, cyberpunk style<|im_end|>"


def _find_free_port() -> int:
    """Return a TCP port number that the OS reported as free on 127.0.0.1.

    The port is obtained by binding a throwaway socket to port 0 and reading
    back the port the OS assigned. The socket is closed before returning, so
    the port is only known to have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        probe.listen(1)
        _, assigned_port = probe.getsockname()
        return assigned_port
    finally:
        probe.close()


def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list:
    """Adjust the default per-stage sampling params for Bagel text2img.

    The list is taken from ``omni.default_sampling_params_list`` and modified
    in place: the first entry gets ``max_tokens``; when a second entry exists
    it gets ``num_inference_steps``.

    Args:
        omni: The Omni instance whose default params are used.
        max_tokens: Maximum tokens for the first stage.
        num_inference_steps: Number of inference steps for the diffusion stage.

    Returns:
        The same params list, after the updates above.
    """
    stage_params = omni.default_sampling_params_list

    stage_params[0].max_tokens = max_tokens  # type: ignore

    has_second_stage = len(stage_params) > 1
    if has_second_stage:
        stage_params[1].num_inference_steps = num_inference_steps  # type: ignore

    return stage_params


def _extract_generated_image(omni_outputs: list) -> Image.Image | None:
    """Return the first image found in the outputs of ``omni.generate()``.

    For each output, a non-empty ``images`` attribute on the output itself is
    checked first; otherwise every item of a non-empty ``request_output``
    attribute is searched for a non-empty ``images`` attribute.

    Args:
        omni_outputs: List of outputs from omni.generate().

    Returns:
        The first image encountered, or None when no output carries one.
    """
    for output in omni_outputs:
        top_level_images = getattr(output, "images", None)
        if top_level_images:
            return top_level_images[0]

        per_stage_outputs = getattr(output, "request_output", None)
        if not per_stage_outputs:
            continue

        for stage_output in per_stage_outputs:
            stage_images = getattr(stage_output, "images", None)
            if stage_images:
                return stage_images[0]

    return None


def _validate_pixels(
    image: Image.Image,
    reference_pixels: list[dict[str, Any]] = REFERENCE_PIXELS,
    tolerance: int = PIXEL_TOLERANCE,
) -> None:
    """Check sampled pixels of ``image`` against reference RGB values.

    Only the first three values of each sampled pixel are compared, channel
    by channel, against the reference triple.

    Args:
        image: The PIL Image to validate.
        reference_pixels: Dicts with a 'position' (x, y) and an 'rgb' (R, G, B) entry.
        tolerance: Largest accepted absolute difference for any single channel.

    Raises:
        AssertionError: On the first pixel with a channel outside the tolerance.
    """
    for ref in reference_pixels:
        x, y = ref["position"]
        expected = ref["rgb"]
        actual = image.getpixel((x, y))[:3]

        channel_deltas = [abs(got - want) for got, want in zip(actual, expected)]
        within_tolerance = all(delta <= tolerance for delta in channel_deltas)
        assert within_tolerance, (
            f"Pixel mismatch at ({x}, {y}): expected {expected}, got {actual}"
        )


def _generate_bagel_image(omni: Omni, prompt: str = DEFAULT_PROMPT) -> Image.Image:
    """Run one image-modality generation request and return the resulting image.

    Sampling params come from ``_configure_sampling_params`` with its default
    settings; the outputs of ``omni.generate`` are fully collected before the
    image is extracted.

    Args:
        omni: The Omni instance to use for generation.
        prompt: The text prompt for image generation.

    Returns:
        The generated PIL Image.

    Raises:
        AssertionError: If no image was produced or its size is not 1024x1024.
    """
    sampling_params = _configure_sampling_params(omni)
    request = {"prompt": prompt, "modalities": ["image"]}

    collected_outputs = list(omni.generate(prompts=[request], sampling_params_list=sampling_params))
    generated_image = _extract_generated_image(collected_outputs)

    assert generated_image is not None, "No images generated"
    assert generated_image.size == (1024, 1024), f"Expected 1024x1024, got {generated_image.size}"
    return generated_image


@pytest.mark.core_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"})
def test_bagel_text2img_shared_memory_connector():
    """Test Bagel text2img with shared memory connector.

    Builds an Omni instance from the shared-memory CI stage config located
    next to this file, generates one image and compares it with the
    reference pixels. The Omni instance is always closed afterwards.
    """
    stage_configs_dir = Path(__file__).parent / "stage_configs"
    omni = Omni(
        model="ByteDance-Seed/BAGEL-7B-MoT",
        stage_configs_path=str(stage_configs_dir / "bagel_sharedmemory_ci.yaml"),
        stage_init_timeout=300,
    )

    try:
        _validate_pixels(_generate_bagel_image(omni))
    finally:
        omni.close()


def _wait_for_port(host: str, port: int, timeout: float = 30) -> bool:
    """Wait until a TCP connection to ``host:port`` can be established.

    Connection attempts are repeated until one succeeds or ``timeout``
    seconds have elapsed. The deadline is tracked with a monotonic clock, so
    the total wait never exceeds ``timeout`` by more than scheduling jitter,
    regardless of how long individual attempts take.

    Args:
        host: The host address.
        port: The port number.
        timeout: Maximum seconds to wait; fractional values are accepted.
            A value <= 0 performs no connection attempt.

    Returns:
        True if a connection was established in time, False otherwise.
    """
    poll_interval = 1.0
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            # Cap the per-attempt timeout so one slow connect cannot overrun the deadline.
            with socket.create_connection((host, port), timeout=min(poll_interval, remaining)):
                return True
        except OSError:
            # Covers refused connections and timeouts as well as other
            # connect failures (all are OSError subclasses); keep polling.
            time.sleep(min(poll_interval, max(0.0, deadline - time.monotonic())))


def _cleanup_mooncake_processes(timeout_secs: int = 5) -> None:
    """Terminate processes whose command line matches ``mooncake_master``.

    Sends the default ``pkill`` signal first, then polls ``pgrep`` every half
    second. If matching processes are still reported once ``timeout_secs``
    has elapsed, ``pkill -9`` is issued. The function always sleeps one
    more second before returning.

    Args:
        timeout_secs: Maximum seconds to wait for graceful termination.
    """
    pattern = "mooncake_master"
    silenced = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}

    subprocess.run(["pkill", "-f", pattern], **silenced)

    began_at = time.time()
    needs_force_kill = True
    while time.time() - began_at < timeout_secs:
        probe = subprocess.run(["pgrep", "-f", pattern], **silenced)
        if probe.returncode != 0:
            # pgrep found nothing: no force kill required.
            needs_force_kill = False
            break
        time.sleep(0.5)

    if needs_force_kill:
        subprocess.run(["pkill", "-9", "-f", pattern], **silenced)

    time.sleep(1)


def _load_mooncake_config(host: str, rpc_port: int, http_port: int) -> str:
    """Render the Mooncake stage-config template into a temporary YAML file.

    The template ``stage_configs/bagel_mooncake_ci.yaml`` next to this file is
    read and its ``${MOONCAKE_HOST}``, ``${MOONCAKE_RPC_PORT}`` and
    ``${MOONCAKE_HTTP_PORT}`` placeholders are replaced.

    Args:
        host: Mooncake host address.
        rpc_port: RPC port for Mooncake master.
        http_port: HTTP metadata server port.

    Returns:
        Path of the temporary ``.yaml`` file holding the rendered config. The
        file is not deleted automatically.
    """
    template_path = Path(__file__).parent / "stage_configs" / "bagel_mooncake_ci.yaml"
    rendered = template_path.read_text()

    # Applied one after another, in this order.
    substitutions = (
        ("${MOONCAKE_HOST}", host),
        ("${MOONCAKE_RPC_PORT}", str(rpc_port)),
        ("${MOONCAKE_HTTP_PORT}", str(http_port)),
    )
    for placeholder, value in substitutions:
        rendered = rendered.replace(placeholder, value)

    # delete=False keeps the file on disk after it is closed.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as rendered_file:
        rendered_file.write(rendered)
    return rendered_file.name


@pytest.mark.core_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"})
def test_bagel_text2img_mooncake_connector():
    """Test Bagel text2img with Mooncake connector for inter-stage communication.

    Starts a ``mooncake_master`` process on freshly picked ports, renders the
    Mooncake stage config against those ports, generates one image and
    compares it with the reference pixels.

    Cleanup is layered so that every step runs even if an earlier one raises:
    the Omni instance is closed, the temporary config file is removed, and
    the ``mooncake_master`` process group is killed and reaped.
    """
    mooncake_host = "127.0.0.1"
    mooncake_rpc_port = _find_free_port()
    mooncake_http_port = _find_free_port()
    mooncake_metrics_port = _find_free_port()

    mooncake_master_proc = None
    temp_config_file = None
    omni = None

    try:
        _cleanup_mooncake_processes()

        # Start mooncake_master
        mooncake_master_proc = subprocess.Popen(
            [
                "mooncake_master",
                f"--rpc_port={mooncake_rpc_port}",
                "--enable_http_metadata_server=true",
                "--http_metadata_server_host=0.0.0.0",
                f"--http_metadata_server_port={mooncake_http_port}",
                f"--metrics_port={mooncake_metrics_port}",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            # Runs setsid() in the child, like preexec_fn=os.setsid, but without
            # the thread-safety caveats of preexec_fn. The child leads its own
            # process group, which lets the cleanup kill the whole group.
            start_new_session=True,
        )

        assert _wait_for_port(mooncake_host, mooncake_rpc_port), "mooncake_master failed to start"

        # Create temp config and initialize Omni
        temp_config_file = _load_mooncake_config(
            host=mooncake_host,
            rpc_port=mooncake_rpc_port,
            http_port=mooncake_http_port,
        )

        omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=temp_config_file, stage_init_timeout=300)

        generated_image = _generate_bagel_image(omni)
        _validate_pixels(generated_image)

    finally:
        try:
            if omni is not None:
                omni.close()
        finally:
            # Runs even when omni.close() raises, so the temp file and the
            # mooncake_master process are never left behind.
            if temp_config_file is not None:
                try:
                    os.unlink(temp_config_file)
                except OSError:
                    pass
            if mooncake_master_proc is not None:
                try:
                    os.killpg(os.getpgid(mooncake_master_proc.pid), signal.SIGKILL)
                except OSError:
                    # Best effort: the process may already be gone.
                    pass
                try:
                    # Reap the child so it does not linger as a zombie.
                    mooncake_master_proc.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    pass