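"""Integration tests for the FLUX.1 tool models (Canny, Depth, Fill, Redux).

Each test calls ``run_test`` with the precision returned by ``get_precision()``
and checks the result against a precision-dependent expected LPIPS. All tests
are skipped on Turing GPUs.
"""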
import pytest
import torch

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev():
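    """FLUX.1-Canny-dev on MJHQ-control at 1024x1024, 30 steps; LPIPS checked per precision."""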
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev():
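    """FLUX.1-Depth-dev on MJHQ-control at 1024x1024, 30 steps; LPIPS checked per precision."""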
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.137 if get_precision() == "int4" else 0.102,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev():
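    """FLUX.1-Fill-dev inpainting on MJHQ-control at 1024x1024, 30 steps; LPIPS checked per precision."""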
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
    )


# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_dev_canny_lora():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="MJHQ-control",
#         task="canny",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         lora_names="canny",
#         lora_strengths=0.85,
#         cache_threshold=0,
#         expected_lpips=0.081,
#     )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
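    """FLUX.1-dev with the depth LoRA (strength 0.85) on MJHQ-control; LPIPS checked per precision."""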
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo():
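    """FLUX.1-Fill-dev with the turbo8 LoRA at 8 inference steps; LPIPS checked per precision."""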
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux():
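    """FLUX.1-dev redux task on MJHQ-control at 20 steps, guidance 2.5; LPIPS checked per precision."""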
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=20,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
    )