import gc
import random

import numpy as np
import pytest
import torch


def _set_seed():
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.mps.manual_seed(0)
    np.random.seed(0)
    random.seed(0)


@pytest.hookimpl(wrapper=True)
def pytest_runtest_call(item):
    """Seed all RNGs before a test's call phase and skip tests that need CUDA.

    This is a hook *wrapper* around pytest's built-in ``pytest_runtest_call``,
    which is what actually invokes ``item.runtest()``. ``pytest_runtest_call``
    is not a ``firstresult`` hook, so a plain implementation that called
    ``item.runtest()`` itself would run every passing test twice: once here
    and once more in pytest's own implementation, without re-seeding.

    NOTE(review): ``wrapper=True`` needs pluggy >= 1.1 (documented by pytest 8).
    Confirm this against the project's minimum pytest/pluggy versions.

    Raises:
        pytest.skip.Exception: if the test raised the AssertionError
            "Torch not compiled with CUDA enabled", or a RuntimeError whose
            message contains "Found no NVIDIA driver on your system".
        Any other exception raised by the test propagates unchanged.
    """
    _set_seed()
    try:
        return (yield)
    except AssertionError as err:
        # CPU-only Torch build asked to use CUDA (per the message text).
        if str(err) == "Torch not compiled with CUDA enabled":
            pytest.skip("Torch not compiled with CUDA enabled")
        raise
    except RuntimeError as err:
        # CUDA-enabled Torch build, but no CUDA-capable device found
        if "Found no NVIDIA driver on your system" in str(err):
            pytest.skip("No NVIDIA driver found")
        raise


@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item, nextitem):
    """Free cached accelerator memory after a test is torn down.

    Scheduled as late as possible among teardown hooks (``trylast=True``).
    A full garbage collection runs first, presumably so that objects only
    reachable through reference cycles are released before the allocator
    cache is emptied. CUDA takes precedence; MPS is only checked when CUDA
    is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        torch.mps.empty_cache()


@pytest.fixture(scope="session")
def requires_cuda() -> bool:
    """Session-scoped fixture that skips the requesting test unless CUDA is usable.

    Returns:
        True when ``torch.cuda.is_available()`` reports a usable device;
        otherwise the test is skipped and nothing is returned.
    """
    if torch.cuda.is_available():
        return True
    pytest.skip("CUDA is required")