test_flux1_depth_dev.py 5.91 KB
Newer Older
Muyang Li's avatar
Muyang Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gc
import os
from pathlib import Path

import pytest
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_gpu_memory, get_precision, is_turing

from ...utils import already_generate, compute_lpips
from ..utils import run_pipeline

precision = get_precision()
torch_dtype = torch.float16 if is_turing() else torch.bfloat16
dtype_str = "fp16" if torch_dtype == torch.float16 else "bf16"


class Case:
    def __init__(
        self,
        rank: int = 32,
        batch_size: int = 1,
        width: int = 1024,
        height: int = 1024,
        num_inference_steps: int = 20,
        attention_impl: str = "flashattn2",
        expected_lpips: dict[str, float] = {},
        model_name: str = "flux.1-depth-dev",
        repo_id: str = "black-forest-labs/FLUX.1-Depth-dev",
    ):
        self.rank = rank
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.num_inference_steps = num_inference_steps
        self.attention_impl = attention_impl
        self.expected_lpips = expected_lpips
        self.model_name = model_name
        self.repo_id = repo_id

        self.model_path = (
            f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{precision}_r{rank}-flux.1-depth-dev.safetensors"
        )

        ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
        folder_name = f"w{width}h{height}t{num_inference_steps}"

        self.save_dir_16bit = Path(ref_root) / model_name / dtype_str / folder_name
        self.save_dir_nunchaku = (
            Path("test_results")
            / "nunchaku"
            / model_name
            / f"{precision}_r{rank}-{dtype_str}"
            / f"{folder_name}-bs{batch_size}"
        )

        self.pipeline_cls = FluxControlPipeline
        self.forward_kwargs = {
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": 10,
        }


@pytest.mark.parametrize(
    "case", [pytest.param(Case(expected_lpips={"int4-bf16": 0.13, "fp4-bf16": 0.11}), id="flux.1-depth-dev-r32")]
)
def test_flux_depth_dev(case: Case):
    batch_size = case.batch_size
    expected_lpips = case.expected_lpips
    repo_id = case.repo_id

    dataset = [
        {
            "prompt": "the insanely extreme muscle car, Big foot wheels, dragster style, flames, 6 wheels ",
            "filename": "1ce4f3b8627ab16e8f09e6e169d8744d32274880",
            "control_image": load_image(
                "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/inputs/1ce4f3b8627ab16e8f09e6e169d8744d32274880-depth.png"
            ).convert("RGB"),
        },
        {
            "prompt": "sunlower, Folk Art ",
            "filename": "8c2fef24a984d4c76bebcfa406b7240fd25d7c36",
            "control_image": load_image(
                "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/inputs/8c2fef24a984d4c76bebcfa406b7240fd25d7c36-depth.png"
            ).convert("RGB"),
        },
        # {
        #     "prompt": "modern realistic allium flowers, clean straight lines, black and white, a lot of white space to color, coloring book style ",
        #     "filename": "94f2b6fc3ab734ccdf6e57f72287f0a6df522dc0",
        #     "control_image": load_image(
        #         "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/inputs/94f2b6fc3ab734ccdf6e57f72287f0a6df522dc0-depth.png"
        #     ).convert("RGB"),
        # },
        # {
        #     "prompt": " Content Spirit Wraith Coin Medium engraved metallic coin Style symmetrical, detailed design Lighting Reflective natural light Colors purples and grays Composition the beast centered, surrounded by elemental symbols, stats, and abilities Create a Spirit Wraith Elemental Guardian Coin featuring a symmetrical, detailed design of the Spirit Wraith guardian at the center, signifying its affinity for the spirit element. The coin should have reflective natural light with mystical purples and ethereal grays. Encircle the guardian with elemental symbols, stats, and abilities relevant to its spiritbased prowess. ",
        #     "filename": "d38575d92bfd143930c4e57daa69aad5a4be48a6",
        #     "control_image": load_image(
        #         "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/inputs/d38575d92bfd143930c4e57daa69aad5a4be48a6-depth.png"
        #     ).convert("RGB"),
        # },
    ]

    if not already_generate(case.save_dir_16bit, len(dataset)):
        pipeline = case.pipeline_cls.from_pretrained(case.repo_id, torch_dtype=torch_dtype)
        if get_gpu_memory() > 25:
            pipeline.enable_model_cpu_offload()
        else:
            pipeline.enable_sequential_cpu_offload()
        run_pipeline(
            dataset=dataset,
            batch_size=case.batch_size,
            pipeline=pipeline,
            save_dir=case.save_dir_16bit,
            forward_kwargs=case.forward_kwargs,
        )

    transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(case.model_path, torch_dtype=torch_dtype)

    pipe = case.pipeline_cls.from_pretrained(repo_id, transformer=transformer, torch_dtype=torch_dtype)

    if get_gpu_memory() > 18:
        pipe.enable_model_cpu_offload()
    else:
        transformer.set_offload(True, use_pin_memory=True, num_blocks_on_gpu=20)
        pipe._exclude_from_cpu_offload.append("transformer")
        pipe.enable_sequential_cpu_offload()

    run_pipeline(
        dataset=dataset,
        batch_size=batch_size,
        pipeline=pipe,
        save_dir=case.save_dir_nunchaku,
        forward_kwargs=case.forward_kwargs,
    )
    del transformer
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    lpips = compute_lpips(case.save_dir_16bit, case.save_dir_nunchaku)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips[f"{precision}-{dtype_str}"] * 1.10