Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
a4875fc0
Unverified
Commit
a4875fc0
authored
Aug 14, 2024
by
Matthew Douglas
Committed by
GitHub
Aug 14, 2024
Browse files
Bugfix: Load correct nocublaslt library variant when BNB_CUDA_VERSION override is set (#1318)
parent
6d714a5c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
7 deletions
+8
-7
bitsandbytes/cextension.py
bitsandbytes/cextension.py
+2
-7
tests/test_cuda_setup_evaluator.py
tests/test_cuda_setup_evaluator.py
+6
-0
No files found.
bitsandbytes/cextension.py
View file @
a4875fc0
...
@@ -20,6 +20,7 @@ import ctypes as ct
...
@@ -20,6 +20,7 @@ import ctypes as ct
import
logging
import
logging
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
import
re
import
torch
import
torch
...
@@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
...
@@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
override_value
=
os
.
environ
.
get
(
"BNB_CUDA_VERSION"
)
override_value
=
os
.
environ
.
get
(
"BNB_CUDA_VERSION"
)
if
override_value
:
if
override_value
:
library_name_stem
,
_
,
library_name_ext
=
library_name
.
rpartition
(
"."
)
library_name
=
re
.
sub
(
"cuda\d+"
,
f
"cuda
{
override_value
}
"
,
library_name
,
count
=
1
)
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
# let's remove any trailing numbers:
library_name_stem
=
library_name_stem
.
rstrip
(
"0123456789"
)
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
# let's tack the new version number and the original extension back on.
library_name
=
f
"
{
library_name_stem
}{
override_value
}
.
{
library_name_ext
}
"
logger
.
warning
(
logger
.
warning
(
f
"WARNING: BNB_CUDA_VERSION=
{
override_value
}
environment variable detected; loading
{
library_name
}
.
\n
"
f
"WARNING: BNB_CUDA_VERSION=
{
override_value
}
environment variable detected; loading
{
library_name
}
.
\n
"
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
\n
"
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
\n
"
...
...
tests/test_cuda_setup_evaluator.py
View file @
a4875fc0
...
@@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
...
@@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
assert
"BNB_CUDA_VERSION"
in
caplog
.
text
# did we get the warning?
assert
"BNB_CUDA_VERSION"
in
caplog
.
text
# did we get the warning?
def
test_get_cuda_bnb_library_path_override_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
,
caplog
):
monkeypatch
.
setenv
(
"BNB_CUDA_VERSION"
,
"125"
)
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda125_nocublaslt"
assert
"BNB_CUDA_VERSION"
in
caplog
.
text
# did we get the warning?
def
test_get_cuda_bnb_library_path_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
):
def
test_get_cuda_bnb_library_path_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
):
monkeypatch
.
delenv
(
"BNB_CUDA_VERSION"
,
raising
=
False
)
monkeypatch
.
delenv
(
"BNB_CUDA_VERSION"
,
raising
=
False
)
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda111_nocublaslt"
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda111_nocublaslt"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment