import pytest
import torch

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


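# The tests below exercise the quantized FLUX.1 tool pipelines (Canny, Depth,
# Fill, and Redux). Each one runs 1024x1024 inference on the MJHQ-control
# dataset and checks the result against a precision-dependent LPIPS threshold
# (the int4 and fp4 quantized models use different bounds). All tests are
# skipped on Turing GPUs.
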
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.137 if get_precision() == "int4" else 0.092,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
    )


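# NOTE: the flux.1-dev canny-LoRA test below is currently disabled.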
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_dev_canny_lora():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="MJHQ-control",
#         task="canny",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         lora_names="canny",
#         lora_strengths=0.85,
#         cache_threshold=0,
#         expected_lpips=0.081,
#     )


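# Unlike the dedicated flux.1-canny-dev / flux.1-depth-dev checkpoints above,
# the next test applies depth control as a LoRA ("depth", strength 0.85) on top
# of the base flux.1-dev model.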
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
    )


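# Fill pipeline with the 8-step turbo LoRA ("turbo8"), so only 8 inference
# steps are used instead of 30.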
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=20,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
    )