# test_flux_tools.py
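"""Tests for the FLUX.1 tool models (Canny, Depth, Fill, Redux).

Each test calls ``run_test`` at the precision reported by ``get_precision()``
on the MJHQ-control dataset and compares the output against an expected
LPIPS score. All tests are skipped on Turing GPUs.
"""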
import pytest
import torch

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


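# FLUX.1-Canny-dev: canny-conditioned generation, 30 steps at 1024x1024.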
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
    )


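# FLUX.1-Depth-dev: depth-conditioned generation, 30 steps at 1024x1024.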
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
    )


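# FLUX.1-Fill-dev: fill (inpainting) task, 30 steps at 1024x1024.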
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.046,
    )


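# FLUX.1-dev with the canny control LoRA (currently disabled).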
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_dev_canny_lora():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="MJHQ-control",
#         task="canny",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         lora_names="canny",
#         lora_strengths=0.85,
#         cache_threshold=0,
#         expected_lpips=0.081,
#     )


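# FLUX.1-dev with the depth control LoRA at strength 0.85.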
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.181,
    )


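# FLUX.1-Fill-dev with the 8-step turbo LoRA (num_inference_steps=8).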
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.036,
    )


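# FLUX.1-dev redux (image variation) task, 20 steps at guidance scale 2.5.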
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=20,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=(0.162 if get_precision() == "int4" else 0.198),
    )