test_diffusion_cpu_offload.py 2.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import sys
from pathlib import Path

import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

from tests.utils import GPUMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

# File-level ruff directive: suppress E402 (module-level import not at top of
# file), needed because `Omni` is imported only after the sys.path tweak below.
# ruff: noqa: E402
# Third ancestor directory of this file's resolved path; presumably the
# repository root given the name -- TODO confirm against the repo layout.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Put that directory at the front of sys.path (once) so it takes precedence
# when resolving the import below.
# NOTE(review): `vllm_omni.inputs.data` and `vllm_omni.platforms` are already
# imported above, before this insertion, so the tweak cannot affect how those
# were resolved -- confirm whether it is still required.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

# Model identifiers the offload test below is parametrized over.
models = ["riverclouds/qwen_image_random"]


def inference(model_name: str, offload: bool = True):
    """Run one diffusion generation and return the peak GPU memory observed.

    A ``GPUMemoryMonitor`` is started *before* the model is constructed, so
    the reported peak covers model loading as well as generation.

    Args:
        model_name: Model identifier passed to ``Omni(model=...)``.
        offload: Forwarded as ``enable_cpu_offload`` to ``Omni``.

    Returns:
        ``monitor.peak_used_mb`` as sampled by the monitor up to the end of
        generation (presumably megabytes, per the attribute name).

    Raises:
        Any exception raised while building the model or generating; the
        monitor is stopped before the exception propagates.
    """
    current_omni_platform.empty_cache()
    device_index = torch.cuda.current_device()
    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
    monitor.start()
    # Always stop the monitor, even if model construction or generation
    # fails; otherwise it would keep sampling into the next run/test.
    try:
        m = Omni(model=model_name, enable_cpu_offload=offload)
        # Resets torch's own peak counters; the value returned below comes
        # from the monitor, not from these statistics.
        torch.cuda.reset_peak_memory_stats(device=device_index)
        height = 256
        width = 256

        m.generate(
            "a photo of a cat sitting on a laptop keyboard",
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                num_inference_steps=9,
                guidance_scale=0.0,
                # Fixed seed so both runs perform the same generation.
                generator=torch.Generator("cuda").manual_seed(42),
            ),
        )
        # Read the peak before the `finally` clause stops the monitor,
        # matching the original read-then-stop order.
        return monitor.peak_used_mb
    finally:
        monitor.stop()


@pytest.mark.skipif(current_omni_platform.is_npu() or current_omni_platform.is_rocm(), reason="Hardware not supported")
@pytest.mark.parametrize("model_name", models)
def test_cpu_offload_diffusion_model(model_name: str):
    """Check that CPU offload lowers peak GPU memory by a clear margin.

    Runs the same generation once without and once with CPU offload,
    cleaning up distributed state and memory in between, and requires the
    offloaded run's peak to be more than ``min_saving_mb`` below the
    non-offloaded run's peak.
    """
    # Required gap between the two peaks, in the same unit as the values
    # returned by `inference` (reported as MB below).
    min_saving_mb = 2500
    try:
        no_offload_peak_memory = inference(model_name, offload=False)
        # Release the first model's state before measuring the second run.
        cleanup_dist_env_and_memory()
        offload_peak_memory = inference(model_name, offload=True)
    except Exception as exc:
        # Surface what actually went wrong in the failure reason instead of
        # a bare "Inference failed".
        pytest.fail(f"Inference failed: {type(exc).__name__}: {exc}")
    print(f"Offload peak memory: {offload_peak_memory} MB")
    print(f"No offload peak memory: {no_offload_peak_memory} MB")
    assert offload_peak_memory + min_saving_mb < no_offload_peak_memory, (
        f"Offload peak memory {offload_peak_memory} MB should be more than {min_saving_mb} MB "
        f"below no offload peak memory {no_offload_peak_memory} MB"
    )