test_flux_schnell.py 969 Bytes
Newer Older
muyangli's avatar
muyangli committed
1
2
import pytest

3
from nunchaku.utils import get_precision, is_turing
Muyang Li's avatar
Muyang Li committed
4

5
from .utils import run_test
muyangli's avatar
muyangli committed
6
7


muyangli's avatar
muyangli committed
8
# Resolve the precision once at import time. The expected-LPIPS values below and
# the precision passed to ``run_test`` then come from the same value and cannot
# disagree.
_PRECISION = get_precision()
_IS_INT4 = _PRECISION == "int4"


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips",
    [
        # expected_lpips: the first value applies when precision is "int4",
        # the second value for any other precision.
        (1024, 1024, "flashattn2", False, 0.141 if _IS_INT4 else 0.126),
        (1024, 1024, "nunchaku-fp16", False, 0.139 if _IS_INT4 else 0.126),
        (1920, 1080, "nunchaku-fp16", False, 0.190 if _IS_INT4 else 0.138),
        (2048, 2048, "nunchaku-fp16", True, 0.166 if _IS_INT4 else 0.120),
    ],
)
def test_flux_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
    """Run the shared ``run_test`` harness for one parametrized configuration.

    Args:
        height: Output image height passed to ``run_test``.
        width: Output image width passed to ``run_test``.
        attention_impl: Attention implementation name (``"flashattn2"`` or
            ``"nunchaku-fp16"`` in the parametrization above).
        cpu_offload: Whether ``run_test`` should be asked to enable CPU offload.
        expected_lpips: Precision-dependent LPIPS value forwarded to
            ``run_test``. It is presumably an upper bound that the harness
            asserts against; confirm in ``.utils.run_test``.
    """
    run_test(
        precision=_PRECISION,
        height=height,
        width=width,
        attention_impl=attention_impl,
        cpu_offload=cpu_offload,
        expected_lpips=expected_lpips,
    )