test_flux_double_fb_cache.py
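"""LPIPS regression tests for FLUX.1-dev with the double FB cache enabled.

Each case runs ``run_test`` with multi- and single-block residual-difference
thresholds and checks the resulting LPIPS score against an expected bound
that depends on the precision reported by ``get_precision()``.
"""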
import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
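        # (use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single,
        #  height, width, num_inference_steps, lora_name, lora_strength, expected_lpips)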
        (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.165),
        (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.161),
    ],
)
def test_flux_dev_double_fb_cache(
    use_double_fb_cache: bool,
    residual_diff_threshold_multi: float,
    residual_diff_threshold_single: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: str,
    lora_strength: float,
    expected_lpips: float,
):
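    """Run FLUX.1-dev with the double FB cache and compare the LPIPS score against ``expected_lpips``."""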
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        use_double_fb_cache=use_double_fb_cache,
        residual_diff_threshold_multi=residual_diff_threshold_multi,
        residual_diff_threshold_single=residual_diff_threshold_single,
        expected_lpips=expected_lpips,
    )