# test_shuttle_jaguar.py
import pytest

3
4
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
5
6


muyangli's avatar
muyangli committed
7
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
)
def test_shuttle_jaguar(
    height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
) -> None:
    """Run the shared ``run_test`` helper against the "shuttle-jaguar" model.

    The test is skipped whenever ``is_turing()`` returns a truthy value, and
    is parametrized with a single case (1024x1024, "nunchaku-fp16" attention,
    no CPU offload, expected LPIPS 0.209).

    Args:
        height: Forwarded to ``run_test`` as ``height``.
        width: Forwarded to ``run_test`` as ``width``.
        attention_impl: Forwarded to ``run_test`` as ``attention_impl``.
        cpu_offload: Forwarded to ``run_test`` as ``cpu_offload``.
        expected_lpips: Forwarded to ``run_test`` as ``expected_lpips``.
            NOTE(review): presumably the LPIPS value the helper compares the
            generated output against -- confirm in ``.utils.run_test``.
    """
    run_test(
        # Precision is not parametrized; it is taken from get_precision() at call time.
        precision=get_precision(),
        model_name="shuttle-jaguar",
        height=height,
        width=width,
        attention_impl=attention_impl,
        cpu_offload=cpu_offload,
        expected_lpips=expected_lpips,
    )