# test_cuda_setup_evaluator.py (702 bytes)
1
import os
2
3
from pathlib import Path

# Commit attribution: Aarni Koskela
4
5
6
import torch


# Hardcoded sanity check for the BNB_CUDA_VERSION override. Not good, but a
# sanity check for now.
# TODO: replace the hardcoded cluster path and re-enable real assertions on
# the loaded binary.
def test_manual_override(requires_cuda):
    """Run bitsandbytes' CUDA setup with a manual CUDA 12.2 override.

    Points ``CUDA_HOME`` at a hardcoded CUDA 12.2 install and forces
    ``BNB_CUDA_VERSION=122`` before importing ``bitsandbytes``. It then asks
    the ``CUDASetup`` singleton which binary it selected.

    The environment overrides are scoped to this test with
    ``unittest.mock.patch.dict``, so they are restored afterwards and do
    not leak into other tests in the session.

    NOTE(review): if ``bitsandbytes`` was already imported earlier in the
    session, the import below is a module-cache hit, and the overrides may
    not be re-read. Confirm whether ``CUDASetup`` re-evaluates the
    environment.
    """
    from unittest import mock

    import pytest

    manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2'))

    # "12.1" -> "121": same string format as BNB_CUDA_VERSION below.
    pytorch_version = torch.version.cuda.replace('.', '')

    # The override is only meaningful when it differs from PyTorch's own
    # CUDA version. The previous check compared this str to the int 122,
    # which can never be equal, so it always passed.
    if pytorch_version == '122':
        pytest.skip("PyTorch already uses CUDA 12.2; override would be a no-op")

    overrides = {
        # Previously the literal '{manual_cuda_path}' (missing f-prefix).
        'CUDA_HOME': manual_cuda_path,
        'BNB_CUDA_VERSION': '122',
    }
    with mock.patch.dict(os.environ, overrides):
        import bitsandbytes as bnb

        loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name
        # TODO: once the hardcoded path is gone, enable:
        #   assert manual_cuda_path in os.environ['LD_LIBRARY_PATH']
        #   assert loaded_lib == 'libbitsandbytes_cuda122.so'







# Commit attribution: Tom Aarsen
29