"requirements/nccl_wz/topo-input.xml" did not exist on "f098f25042dcce18b3b669a36d9a42dd11f05823"
test_sdxl.py 5.11 KB
Newer Older
dengdong's avatar
dengdong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
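"""Integration tests for the Nunchaku-quantized SDXL UNet.

Two checks: LPIPS similarity between fp16 reference images and quantized
outputs, and the end-to-end latency of the quantized pipeline.
"""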
import gc
import os
from pathlib import Path

import pytest
import torch
from diffusers import StableDiffusionXLPipeline

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel
from nunchaku.utils import get_precision, is_turing

from ...flux.utils import already_generate, compute_lpips, hash_str_to_int
from .test_sdxl_turbo import plot, run_benchmark


@pytest.mark.skipif(
    is_turing() or get_precision() == "fp4", reason="Skipped on Turing GPUs or when precision is FP4"
)
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_sdxl_lpips(expected_lpips: float):
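    """Check that quantized SDXL stays perceptually close to the fp16 baseline.

    Reference and quantized images are generated once and cached under
    NUNCHAKU_TEST_CACHE_ROOT (default: test_results/ref); the LPIPS score
    between the two image sets must stay within 15% of the expected value.
    """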
    gc.collect()
    torch.cuda.empty_cache()

    precision = get_precision()

    ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
    results_dir_original = ref_root / "fp16" / "sdxl"
    results_dir_nunchaku = ref_root / precision / "sdxl"

    os.makedirs(results_dir_original, exist_ok=True)
    os.makedirs(results_dir_nunchaku, exist_ok=True)

    prompts = [
        "Ilya Repin, Moebius, Yoshitaka Amano, 1980s nubian punk rock glam core fashion shoot, closeup, 35mm ",
        "A honeybee sitting on a flower in a garden full of yellow flowers",
        "Vibrant, tropical rainforest, teeming with wildlife, nature photography ",
        "very realistic photo of barak obama in a wing eating contest",
        "oil paint of colorful wildflowers in a meadow, Paul Signac divisionism style ",
    ]

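    # Generate the fp16 reference images only if they are not already cached.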
    if not already_generate(results_dir_original, 5):
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, use_safetensors=True, variant="fp16"
        ).to("cuda")

        for prompt in prompts:
            seed = hash_str_to_int(prompt)
            result = pipeline(
                prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
            ).images[0]
            result.save(os.path.join(results_dir_original, f"{seed}.png"))

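        # Release the fp16 pipeline before loading the quantized one so the
        # test fits on GPUs with limited memory; the free-memory print below
        # reports the resulting headroom.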
        del pipeline.unet
        del pipeline.text_encoder
        del pipeline.text_encoder_2
        del pipeline.vae
        del pipeline
        del result
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    free, total = torch.cuda.mem_get_info()
    print(f"After original generation: Free: {free/1024**2:.0f} MB  /  Total: {total/1024**2:.0f} MB")

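    # Generate images with the quantized UNet only if they are not already cached.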
    if not already_generate(results_dir_nunchaku, 5):
        quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
            "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
        )
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            unet=quantized_unet,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
            variant="fp16",
        )
        pipeline = pipeline.to("cuda")
        for prompt in prompts:
            seed = hash_str_to_int(prompt)
            result = pipeline(
                prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
            ).images[0]
            result.save(os.path.join(results_dir_nunchaku, f"{seed}.png"))

        del pipeline
        del quantized_unet
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    free, total = torch.cuda.mem_get_info()
    print(f"After Nunchaku generation: Free: {free/1024**2:.0f} MB  /  Total: {total/1024**2:.0f} MB")

    lpips = compute_lpips(results_dir_original, results_dir_nunchaku)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.15


@pytest.mark.skipif(
    is_turing() or get_precision() == "fp4", reason="Skipped on Turing GPUs or when precision is FP4"
)
@pytest.mark.parametrize("expected_latency", [7.455])
def test_sdxl_time_cost(expected_latency: float):
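    """Benchmark the INT4-quantized SDXL pipeline against a latency budget.

    Averages per-step timings over several runs, saves a latency plot, and
    fails if the estimated per-image latency exceeds the expected value by
    more than 10%.
    """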
    batch_size = 2
    runs = 5
    inference_steps = 50
    guidance_scale = 5.0
    device_name = torch.cuda.get_device_name(0)
    results = {"Nunchaku INT4": []}

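    # Load the SVDQuant INT4 (rank-32) SDXL UNet checkpoint.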
    quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )
    pipeline_quantized = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet=quantized_unet,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        variant="fp16",
    )

    pipeline_quantized = pipeline_quantized.to("cuda")

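    # run_benchmark (shared with test_sdxl_turbo) times individual denoising
    # steps; the mean per-step time multiplied by the step count approximates
    # the latency of one full generation.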
    benchmark_quantized = run_benchmark(
        pipeline_quantized, batch_size, guidance_scale, device_name, runs, inference_steps
    )
    avg_latency = benchmark_quantized.mean() * inference_steps
    results["Nunchaku INT4"].append(avg_latency)

    ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
    plot_save_path = ref_root / "time_cost" / "sdxl"
    os.makedirs(plot_save_path, exist_ok=True)

    plot([batch_size], results, device_name, runs, inference_steps, plot_save_path, "SDXL")

    assert avg_latency < expected_latency * 1.1