test_multiple_batch.py 791 Bytes
Newer Older
muyangli's avatar
muyangli committed
1
# skip this test
muyangli's avatar
update  
muyangli committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test


# Parameter table for the multi-batch test below. Each row is
# (height, width, attention_impl, cpu_offload, expected_lpips, batch_size),
# in the same order as the names given to ``pytest.mark.parametrize``.
_MULTI_BATCH_CASES = [
    (1024, 1024, "nunchaku-fp16", False, 0.140, 2),
    (1920, 1080, "flashattn2", False, 0.160, 4),
]


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
    _MULTI_BATCH_CASES,
)
def test_int4_schnell(
    height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float, batch_size: int
):
    """Run the shared ``run_test`` harness once per row of ``_MULTI_BATCH_CASES``.

    Every case forwards its resolution, attention implementation, CPU-offload
    flag, expected LPIPS value and batch size to ``run_test``. The precision
    comes from ``get_precision()``. The whole test is skipped when
    ``is_turing()`` reports a Turing GPU.

    NOTE(review): ``expected_lpips`` is presumably the LPIPS bound that
    ``run_test`` compares against. Its exact semantics are defined in
    ``.utils.run_test``, so confirm there.

    NOTE(review): the file's first line reads ``# skip this test``, but
    nothing here skips it unconditionally. Only Turing GPUs are skipped.
    Confirm whether an explicit skip was intended.
    """
    # Everything except precision is forwarded unchanged from the
    # parametrized arguments.
    case_kwargs = {
        "height": height,
        "width": width,
        "attention_impl": attention_impl,
        "cpu_offload": cpu_offload,
        "expected_lpips": expected_lpips,
        "batch_size": batch_size,
    }
    run_test(precision=get_precision(), **case_kwargs)