# conftest.py
import pytest
import torch


def _skip_reason_for(exc):
    """Return a skip reason if *exc* signals a missing CUDA setup, else ``None``.

    Matching rules:

    * a ``NotImplementedError`` whose message contains ``"NO_CUBLASLT"``
      -> ``"CUBLASLT not available"``;
    * an ``AssertionError`` whose message is exactly
      ``"Torch not compiled with CUDA enabled"`` -> the same text.

    Any other exception yields ``None`` (i.e. "do not convert to a skip").
    """
    if isinstance(exc, NotImplementedError) and "NO_CUBLASLT" in str(exc):
        return "CUBLASLT not available"
    if (
        isinstance(exc, AssertionError)
        and str(exc) == "Torch not compiled with CUDA enabled"
    ):
        return "Torch not compiled with CUDA enabled"
    return None


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Turn known "no CUDA / no cuBLASLt" test failures into skips.

    This is a hook *wrapper* around pytest's own ``pytest_runtest_call``
    (whose default implementation calls ``item.runtest()``). The hook is
    not ``firstresult``. A plain implementation that calls
    ``item.runtest()`` itself would therefore run *in addition to* the
    default one, executing every passing test twice. Wrapping lets the
    test body run exactly once while still letting us inspect the
    exception it raised.

    Args:
        item: The pytest test item. It is not used directly; the parameter
            name must match the hook specification.
    """
    outcome = yield
    excinfo = outcome.excinfo
    if excinfo is None:
        return  # the test call did not raise; nothing to convert
    reason = _skip_reason_for(excinfo[1])
    if reason is None:
        return  # unrelated failure: leave it in place so it propagates
    skip = pytest.skip.Exception(reason)
    force_exception = getattr(outcome, "force_exception", None)
    if force_exception is not None:
        # Newer pluggy: replace the recorded exception instead of raising
        # from the wrapper teardown, which pluggy discourages.
        force_exception(skip)
    else:
        # Older pluggy (no force_exception): an exception raised after the
        # yield propagates out of the hook call.
        raise skip


@pytest.fixture(scope="session")
def requires_cuda() -> bool:
    """Session-scoped guard: skip any test that depends on it unless CUDA works.

    Returns:
        The result of ``torch.cuda.is_available()``. It is always ``True``
        when the fixture returns, because otherwise ``pytest.skip`` raises.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        return has_cuda
    pytest.skip("CUDA is required")