# test_flux_double_fb_cache.py
from typing import Optional

import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
        # The two cases differ only in step count (30 vs 50) and in the
        # expected_lpips value used when the precision is not "int4".
        (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.165),
        (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
    ],
)
def test_flux_dev_double_fb_cache(
    use_double_fb_cache: bool,
    residual_diff_threshold_multi: float,
    residual_diff_threshold_single: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: Optional[str],
    lora_strength: float,
    expected_lpips: float,
) -> None:
    """Run ``run_test`` on "flux.1-dev" with the double FB cache settings.

    The test is skipped when ``is_turing()`` is true. Every argument comes
    from the ``parametrize`` table above and is forwarded to ``run_test``
    together with a fixed configuration (``guidance_scale=3.5``,
    ``use_qencoder=False``, ``cpu_offload=False``) and the precision
    reported by ``get_precision()``.

    Args:
        use_double_fb_cache: Forwarded as ``use_double_fb_cache``.
        residual_diff_threshold_multi: Forwarded unchanged; presumably the
            cache threshold for the multi-stream blocks -- confirm against
            ``run_test``.
        residual_diff_threshold_single: Forwarded unchanged; presumably the
            cache threshold for the single-stream blocks -- confirm against
            ``run_test``.
        height: Forwarded as ``height``.
        width: Forwarded as ``width``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        lora_name: LoRA name, or ``None`` for no LoRA. When it is not
            ``None`` it is also used as the dataset name.
        lora_strength: Forwarded as ``lora_strengths``.
        expected_lpips: Forwarded as ``expected_lpips``; how it is compared
            is decided inside ``run_test``.
    """
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        # Without a LoRA the "MJHQ" dataset is used; otherwise the dataset is
        # selected by the LoRA's own name.
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        use_double_fb_cache=use_double_fb_cache,
        residual_diff_threshold_multi=residual_diff_threshold_multi,
        residual_diff_threshold_single=residual_diff_threshold_single,
        expected_lpips=expected_lpips,
    )