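"""LPIPS regression tests for FLUX.1-dev LoRA inference with nunchaku.

Each test runs the quantized FLUX.1-dev pipeline with one or more LoRAs
applied and checks that the LPIPS distance to the reference images stays
at or below ``expected_lpips`` (see ``run_test`` in ``.utils`` for the
exact comparison; this description is the assumed contract).

Example invocation (the path is an assumption; adjust to where this file
lives in the repository):

    pytest tests/test_flux_dev_loras.py -k turbo8
"""
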
import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


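# Single-LoRA sweep. Each case is
# (num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips);
# the LoRA name doubles as the dataset (prompt set) name below.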
@pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
@pytest.mark.parametrize(
    "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
    [
        (25, "realism", 0.9, True, 0.136),
        (25, "ghibsky", 1, False, 0.186),
        # (28, "anime", 1, False, 0.284),
        (24, "sketch", 1, True, 0.260),
        # (28, "yarn", 1, False, 0.211),
        # (25, "haunted_linework", 1, True, 0.317),
    ],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name=lora_name,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        attention_impl="nunchaku-fp16",
        cpu_offload=cpu_offload,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        cache_threshold=0,
        expected_lpips=expected_lpips,
    )


# @pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
# def test_flux_dev_hypersd8_1536x2048():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="MJHQ",
#         height=1536,
#         width=2048,
#         num_inference_steps=8,
#         guidance_scale=3.5,
#         use_qencoder=False,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=True,
#         lora_names="hypersd8",
#         lora_strengths=0.125,
#         cache_threshold=0,
#         expected_lpips=0.164,
#     )


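# turbo8 appears to be an 8-step turbo/distillation LoRA, hence
# num_inference_steps=8; the 1024x1920 case also covers a non-square resolution.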
@pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
def test_flux_dev_turbo8_1024x1920():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ",
        height=1024,
        width=1920,
        num_inference_steps=8,
        guidance_scale=3.5,
        use_qencoder=False,
        attention_impl="nunchaku-fp16",
        cpu_offload=True,
        lora_names="turbo8",
        lora_strengths=1,
        cache_threshold=0,
        expected_lpips=0.151,
    )


# @pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
# def test_flux_dev_turbo8_yarn_2048x1024():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="yarn",
#         height=2048,
#         width=1024,
#         num_inference_steps=8,
#         guidance_scale=3.5,
#         use_qencoder=False,
#         cpu_offload=True,
#         lora_names=["turbo8", "yarn"],
#         lora_strengths=[1, 1],
#         cache_threshold=0,
#         expected_lpips=0.255,
#     )


# LoRA composition & large-rank LoRAs
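# All six style LoRAs are loaded alongside turbo8, but five of them at
# strength 0, so only haunted_linework and turbo8 should influence the output.
# Loading the full set presumably exercises the multi-LoRA composition and
# large-rank weight-handling paths even for the zero-strength slots.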
@pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
def test_flux_dev_turbo8_haunted_linework_1024x1024():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="haunted_linework",
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=True,
        lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
        lora_strengths=[0, 0, 0, 0, 0, 1, 1],
        cache_threshold=0,
        expected_lpips=0.310,
    )
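

# For reference, a minimal standalone sketch of what these tests exercise,
# based on nunchaku's public README API (model IDs and method names are
# assumptions; check them against the installed nunchaku version):
#
#   import torch
#   from diffusers import FluxPipeline
#   from nunchaku import NunchakuFluxTransformer2dModel
#
#   transformer = NunchakuFluxTransformer2dModel.from_pretrained(
#       "mit-han-lab/svdq-int4-flux.1-dev"
#   )
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
#   ).to("cuda")
#   transformer.update_lora_params("path/to/lora.safetensors")  # load one LoRA
#   transformer.set_lora_strength(1)                            # scale its effect
#   image = pipe("a photo of a cat", num_inference_steps=25, guidance_scale=3.5).images[0]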