# test_flux_dev.py: image-quality (LPIPS) test for the FLUX.1-dev model.
import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


# Skipped entirely when is_turing() reports a Turing-architecture GPU.
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
    [
        (1024, 1024, 50, "flashattn2", False, 0.139),
        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
    ],
)
def test_flux_dev(
    height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
):
    """Run the "flux.1-dev" model through the shared ``run_test`` helper.

    Each parametrized case selects an output size, a step count, an
    attention implementation and a CPU-offload setting. The numeric
    precision is taken from ``get_precision()``. Presumably it is detected
    from the current hardware; confirm in ``nunchaku.utils``.

    Args:
        height: Output height, forwarded unchanged to ``run_test``.
        width: Output width, forwarded unchanged to ``run_test``.
        num_inference_steps: Number of inference steps to run.
        attention_impl: Attention backend name, e.g. "flashattn2" or
            "nunchaku-fp16" as used in the cases above.
        cpu_offload: Whether ``run_test`` should enable CPU offloading.
        expected_lpips: Forwarded to ``run_test``. Presumably this is the
            LPIPS bound the generated images must satisfy against a
            reference. TODO: confirm in ``.utils.run_test``.
    """
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        attention_impl=attention_impl,
        cpu_offload=cpu_offload,
        expected_lpips=expected_lpips,
    )