import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test
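# Regression test for flux.1-dev with caching enabled: each case renders images at a
# given cache_threshold and resolution (optionally with a LoRA) and checks the LPIPS
# of the outputs against an expected value via run_test (defined in .utils).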


@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
    "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
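        # (cache_threshold, height, width, num_inference_steps, lora_name, lora_strength, expected_lpips)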
        (0.12, 1024, 1024, 30, None, 1, 0.26),
        (0.12, 512, 2048, 30, "anime", 1, 0.4),
    ],
)
def test_flux_dev_loras(
    cache_threshold: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: str,
    lora_strength: float,
    expected_lpips: float,
):
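    """Run flux.1-dev with the given cache threshold and optional LoRA, and check
    the LPIPS of the generated images against ``expected_lpips`` via ``run_test``."""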
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        cache_threshold=cache_threshold,
        expected_lpips=expected_lpips,
    )
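

# Typically run through the repository's pytest suite, e.g. (path is illustrative,
# adjust to wherever this file lives in the repo):
#   pytest -q <path to test_flux_cache.py>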