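"""Regression tests for single and composed LoRAs on the quantized FLUX.1-dev model."""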
import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


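# Single-LoRA cases: each entry runs FLUX.1-dev with one LoRA applied and
# checks the output's LPIPS score against a precision-dependent threshold.
# Commented-out entries are disabled cases kept for reference.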
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
    [
        (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.112),
        # (25, "ghibsky", 1, False, 0.186),
        # (28, "anime", 1, False, 0.284),
        (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
        # (28, "yarn", 1, False, 0.211),
        # (25, "haunted_linework", 1, True, 0.317),
    ],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name=lora_name,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,  # use the unquantized text encoder
        attention_impl="nunchaku-fp16",
        cpu_offload=cpu_offload,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        cache_threshold=0,  # 0 disables caching
        expected_lpips=expected_lpips,
    )


# LoRA composition & large-rank LoRAs: load all seven LoRAs together and
# compose them through per-LoRA strengths.
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_ghibsky_1024x1024():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="haunted_linework",
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=True,
        lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
        lora_strengths=[0, 1, 0, 0, 0, 0, 1],
        cache_threshold=0,
        expected_lpips=0.310 if get_precision() == "int4" else 0.168,
    )