# test_flux_cache.py (author: muyangli)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test


# Skipped whenever is_turing() reports a Turing GPU.
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
        # The expected LPIPS depends on the precision reported by get_precision():
        # 0.212 for "int4", 0.144 otherwise.
        (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
    ],
)
def test_flux_dev_cache(
    cache_threshold: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: "str | None",
    lora_strength: float,
    expected_lpips: float,
):
    """Run the FLUX.1-dev regression check with a cache threshold set.

    Delegates to ``run_test`` with fixed settings: model ``"flux.1-dev"``,
    the precision reported by ``get_precision()``, ``guidance_scale=3.5``,
    ``use_qencoder=False`` and ``cpu_offload=False``.

    Args:
        cache_threshold: Forwarded to ``run_test`` as ``cache_threshold``.
        height: Forwarded to ``run_test`` as ``height``.
        width: Forwarded to ``run_test`` as ``width``.
        num_inference_steps: Forwarded to ``run_test`` as ``num_inference_steps``.
        lora_name: LoRA to apply, or ``None`` for none. Also selects the
            dataset: ``"MJHQ"`` when ``None``, otherwise the LoRA name itself.
        lora_strength: Forwarded to ``run_test`` as ``lora_strengths``.
        expected_lpips: Forwarded to ``run_test`` as ``expected_lpips``.
            Presumably ``run_test`` fails when the measured LPIPS exceeds
            this value. TODO confirm against ``.utils.run_test``.
    """
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        # With no LoRA the generic MJHQ prompts are used; with a LoRA the
        # dataset is named after it.
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        cache_threshold=cache_threshold,
        expected_lpips=expected_lpips,
    )