test_flux_tools.py
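"""
Tests for Nunchaku's quantized FLUX.1 tool pipelines (Canny, Depth, Fill, and
Redux, plus LoRA variants) on the MJHQ-control dataset. Each case calls
``run_test`` with an expected LPIPS value for the active precision
(int4 or fp4); the comparison against that value is handled inside ``run_test``.
"""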
import pytest
import torch

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.137 if get_precision() == "int4" else 0.102,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=20,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
    )