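"""Tests for the FLUX.1 tool pipelines (Canny, Depth, Fill, Redux) in Nunchaku.

Each test runs the quantized pipeline on the MJHQ-control dataset via
``run_test`` and checks the generated images against an expected LPIPS
value. All tests are skipped on Turing GPUs.
"""
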
import pytest
import torch

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.103 if get_precision() == "int4" else 0.164,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.170 if get_precision() == "int4" else 0.120,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.045,
    )


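# Canny control applied to the base flux.1-dev model via the "canny" LoRA,
# rather than the dedicated flux.1-canny-dev checkpoint tested above.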
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_canny_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        lora_names="canny",
        lora_strengths=0.85,
        cache_threshold=0,
        expected_lpips=0.103,
    )


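# Depth control applied to the base flux.1-dev model via the "depth" LoRA,
# rather than the dedicated flux.1-depth-dev checkpoint tested above.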
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.163,
    )


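# Fill task combined with the "turbo8" LoRA; only 8 inference steps are used here.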
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.048,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_redux():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.198 if get_precision() == "int4" else 0.55,  # redux seems to generate different images on 5090
    )