test_qwenimage_lightning.py 8.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import gc
import math
import os
from pathlib import Path

import pytest
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImagePipeline

from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_gpu_memory, get_precision, is_turing

from ...utils import already_generate, compute_lpips
from ..utils import run_pipeline

precision = get_precision()
torch_dtype = torch.float16 if is_turing() else torch.bfloat16
dtype_str = "fp16" if torch_dtype == torch.float16 else "bf16"


model_paths = {
    "qwen-image-lightningv1.0-4steps": "nunchaku-tech/nunchaku-qwen-image/svdq-{precision}_r{rank}-qwen-image-lightningv1.0-{num_inference_steps}steps.safetensors",
    "qwen-image-lightningv1.1-8steps": "nunchaku-tech/nunchaku-qwen-image/svdq-{precision}_r{rank}-qwen-image-lightningv1.1-{num_inference_steps}steps.safetensors",
}
lora_paths = {
    "qwen-image-lightningv1.0-4steps": (
        "lightx2v/Qwen-Image-Lightning",
        "Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors",
    ),
    "qwen-image-lightningv1.1-8steps": (
        "lightx2v/Qwen-Image-Lightning",
        "Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors",
    ),
}


class Case:

    def __init__(self, model_name: str, num_inference_steps: int, rank: int, expected_lpips: dict[str, float]):
        self.model_name = model_name
        self.num_inference_steps = num_inference_steps
        self.rank = rank
        self.expected_lpips = expected_lpips


@pytest.mark.parametrize(
    "case",
    [
        pytest.param(
            Case(
                model_name="qwen-image-lightningv1.0-4steps",
                num_inference_steps=4,
                rank=32,
                expected_lpips={"int4-bf16": 0.35, "fp4-bf16": 0.33},
            ),
            id="qwen-image-lightningv1.0-4steps-r32",
        ),
        pytest.param(
            Case(
                model_name="qwen-image-lightningv1.0-4steps",
                num_inference_steps=4,
                rank=128,
                expected_lpips={"int4-bf16": 0.32, "fp4-bf16": 0.32},
            ),
            id="qwen-image-lightningv1.0-4steps-r128",
        ),
        pytest.param(
            Case(
                model_name="qwen-image-lightningv1.1-8steps",
                num_inference_steps=8,
                rank=32,
                expected_lpips={"int4-bf16": 0.33, "fp4-bf16": 0.34},
            ),
            id="qwen-image-lightningv1.1-8steps-r32",
        ),
        pytest.param(
            Case(
                model_name="qwen-image-lightningv1.1-8steps",
                num_inference_steps=8,
                rank=128,
                expected_lpips={"int4-bf16": 0.31, "fp4-bf16": 0.32},
            ),
            id="qwen-image-lightningv1.1-8steps-r128",
        ),
    ],
)
def test_qwenimage_lightning(case: Case):
    batch_size = 1
    width = 1024
    height = 1024
    true_cfg_scale = 1.0
    rank = case.rank
    expected_lpips = case.expected_lpips[f"{precision}-{dtype_str}"]
    model_name = case.model_name
    num_inference_steps = case.num_inference_steps

    ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
    folder_name = f"w{width}h{height}t{num_inference_steps}g{true_cfg_scale}"
    save_dir_16bit = Path(ref_root) / model_name / dtype_str / folder_name

    repo_id = "Qwen/Qwen-Image"
    dataset = [
        {
            "prompt": """Bookstore window display. A sign displays “New Arrivals This Week”. Below, a shelf tag with the text “Best-Selling Novels Here”. To the side, a colorful poster advertises “Author Meet And Greet on Saturday” with a central portrait of the author. There are four books on the bookshelf, namely “The light between worlds” “When stars are scattered” “The slient patient” “The night circus” Ultra HD, 4K, cinematic composition.""",
            "filename": "bookstore",
        },
        {
            "prompt": "一副典雅庄重的对联悬挂于厅堂之中,房间是个安静古典的中式布置,桌子上放着一些青花瓷,对联上左书“义本生知人机同道善思新”,右书“通云赋智乾坤启数高志远”, 横批“智启通义”,字体飘逸,中间挂在一着一副中国风的画作,内容是岳阳楼。超清,4K,电影级构图",
            "filename": "chinese_room",
        },
        {
            "prompt": '一张企业级高质量PPT页面图像,整体采用科技感十足的星空蓝为主色调,背景融合流动的发光科技线条与微光粒子特效,营造出专业、现代且富有信任感的品牌氛围;页面顶部左侧清晰展示橘红色Alibaba标志,色彩鲜明、辨识度高。主标题位于画面中央偏上位置,使用大号加粗白色或浅蓝色字体写着“通义千问视觉基础模型”,字体现代简洁,突出技术感;主标题下方紧接一行楷体中文文字:“原生中文·复杂场景·自动布局”,字体柔和优雅,形成科技与人文的融合。下方居中排布展示了四张与图片,分别是:一幅写实与水墨风格结合的梅花特写,枝干苍劲、花瓣清雅,背景融入淡墨晕染与飘雪效果,体现坚韧不拔的精神气质;上方写着黑色的楷体"梅傲"。一株生长于山涧石缝中的兰花,叶片修长、花朵素净,搭配晨雾缭绕的自然环境,展现清逸脱俗的文人风骨;上方写着黑色的楷体"兰幽"。一组迎风而立的翠竹,竹叶随风摇曳,光影交错,背景为青灰色山岩与流水,呈现刚柔并济、虚怀若谷的文化意象;上方写着黑色的楷体"竹清"。一片盛开于秋日庭院的菊花丛,花色丰富、层次分明,配以落叶与古亭剪影,传递恬然自适的生活哲学;上方写着黑色的楷体"菊淡"。所有图片采用统一尺寸与边框样式,呈横向排列。页面底部中央用楷体小字写明“2025年8月,敬请期待”,排版工整、结构清晰,整体风格统一且细节丰富,极具视觉冲击力与品牌调性。',
            "filename": "ppt",
        },
    ]

    scheduler_config = {
        "base_image_seq_len": 256,
        "base_shift": math.log(3),  # We use shift=3 in distillation
        "invert_sigmas": False,
        "max_image_seq_len": 8192,
        "max_shift": math.log(3),  # We use shift=3 in distillation
        "num_train_timesteps": 1000,
        "shift": 1.0,
        "shift_terminal": None,  # set shift_terminal to None
        "stochastic_sampling": False,
        "time_shift_type": "exponential",
        "use_beta_sigmas": False,
        "use_dynamic_shifting": True,
        "use_exponential_sigmas": False,
        "use_karras_sigmas": False,
    }
    scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

    if not already_generate(save_dir_16bit, len(dataset)):
        pipe = QwenImagePipeline.from_pretrained(repo_id, scheduler=scheduler, torch_dtype=torch_dtype)
        pipe.load_lora_weights(lora_paths[model_name][0], weight_name=lora_paths[model_name][1])
        pipe.fuse_lora()
        pipe.unload_lora_weights()
        pipe.enable_sequential_cpu_offload()
        run_pipeline(
            dataset=dataset,
            batch_size=1,
            pipeline=pipe,
            save_dir=save_dir_16bit,
            forward_kwargs={
                "width": width,
                "height": height,
                "num_inference_steps": num_inference_steps,
                "true_cfg_scale": true_cfg_scale,
            },
        )
        del pipe
        gc.collect()
        torch.cuda.empty_cache()

    save_dir_nunchaku = (
        Path("test_results")
        / "nunchaku"
        / model_name
        / f"{precision}_r{rank}-{dtype_str}"
        / f"{folder_name}-bs{batch_size}"
    )

    model_path = model_paths[model_name].format(precision=precision, rank=rank, num_inference_steps=num_inference_steps)
    transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(model_path, torch_dtype=torch_dtype)

    pipe = QwenImagePipeline.from_pretrained(repo_id, transformer=transformer, torch_dtype=torch_dtype)

    if get_gpu_memory() > 18:
        pipe.enable_model_cpu_offload()
    else:
        transformer.set_offload(True, use_pin_memory=True, num_blocks_on_gpu=20)
        pipe._exclude_from_cpu_offload.append("transformer")
        pipe.enable_sequential_cpu_offload()

    run_pipeline(
        dataset=dataset,
        batch_size=batch_size,
        pipeline=pipe,
        save_dir=save_dir_nunchaku,
        forward_kwargs={
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "true_cfg_scale": true_cfg_scale,
        },
    )
    del transformer
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    lpips = compute_lpips(save_dir_16bit, save_dir_nunchaku)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.10