test_cuda_setup_evaluator.py 1.54 KB
Newer Older
1
import pytest
2

3
4
from bitsandbytes.cextension import get_cuda_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs
Aarni Koskela's avatar
Aarni Koskela committed
5
6


7
8
9
10
11
12
13
@pytest.fixture
def cuda120_spec() -> CUDASpecs:
    """Provide a CUDA 12.0 spec whose highest compute capability is 8.6."""
    version_tuple = (12, 0)
    return CUDASpecs(
        cuda_version_tuple=version_tuple,
        cuda_version_string="120",
        highest_compute_capability=(8, 6),
    )
14
15


16
17
18
19
20
21
22
@pytest.fixture
def cuda111_noblas_spec() -> CUDASpecs:
    """Provide a CUDA 11.1 spec whose highest compute capability is 7.2.

    Tests in this module expect this spec to resolve to the ``_nocublaslt``
    library variant. Presumably this is because compute capability 7.2 falls
    below the cuBLASLt requirement; confirm in ``bitsandbytes.cextension``.
    """
    capability = (7, 2)
    return CUDASpecs(
        highest_compute_capability=capability,
        cuda_version_tuple=(11, 1),
        cuda_version_string="111",
    )
23

24
25
26
27
28
29
30
31
32
33
34
35

def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
    """With no BNB_CUDA_VERSION override, the library name uses the spec's CUDA version."""
    # Make sure an override from the developer's shell cannot leak into the test.
    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
    resolved = get_cuda_bnb_library_path(cuda120_spec)
    assert resolved.stem == "libbitsandbytes_cuda120"


def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
    """BNB_CUDA_VERSION replaces the detected CUDA version in the library name.

    The override should also be reported in the captured log output.
    """
    monkeypatch.setenv("BNB_CUDA_VERSION", "110")
    resolved = get_cuda_bnb_library_path(cuda120_spec)
    assert resolved.stem == "libbitsandbytes_cuda110"
    # The env-var name must appear in the logs, i.e. a warning was emitted.
    logged = caplog.text
    assert "BNB_CUDA_VERSION" in logged


36
37
38
39
40
41
def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
    """An overridden CUDA version keeps the ``_nocublaslt`` suffix for this spec."""
    monkeypatch.setenv("BNB_CUDA_VERSION", "125")
    expected_stem = "libbitsandbytes_cuda125_nocublaslt"
    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == expected_stem
    # The env-var name must appear in the logs, i.e. a warning was emitted.
    assert "BNB_CUDA_VERSION" in caplog.text


42
43
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
    """Without an override, this spec resolves to the CUDA 11.1 ``_nocublaslt`` library."""
    # Make sure an override from the developer's shell cannot leak into the test.
    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"