Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
a4875fc0
Unverified
Commit
a4875fc0
authored
Aug 14, 2024
by
Matthew Douglas
Committed by
GitHub
Aug 14, 2024
Browse files
Bugfix: Load correct nocublaslt library variant when BNB_CUDA_VERSION override is set (#1318)
parent
6d714a5c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
7 deletions
+8
-7
bitsandbytes/cextension.py
bitsandbytes/cextension.py
+2
-7
tests/test_cuda_setup_evaluator.py
tests/test_cuda_setup_evaluator.py
+6
-0
No files found.
bitsandbytes/cextension.py
View file @
a4875fc0
...
...
@@ -20,6 +20,7 @@ import ctypes as ct
import
logging
import
os
from
pathlib
import
Path
import
re
import
torch
...
...
@@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
override_value
=
os
.
environ
.
get
(
"BNB_CUDA_VERSION"
)
if
override_value
:
library_name_stem
,
_
,
library_name_ext
=
library_name
.
rpartition
(
"."
)
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
# let's remove any trailing numbers:
library_name_stem
=
library_name_stem
.
rstrip
(
"0123456789"
)
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
# let's tack the new version number and the original extension back on.
library_name
=
f
"
{
library_name_stem
}{
override_value
}
.
{
library_name_ext
}
"
library_name
=
re
.
sub
(
"cuda\d+"
,
f
"cuda
{
override_value
}
"
,
library_name
,
count
=
1
)
logger
.
warning
(
f
"WARNING: BNB_CUDA_VERSION=
{
override_value
}
environment variable detected; loading
{
library_name
}
.
\n
"
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
\n
"
...
...
tests/test_cuda_setup_evaluator.py
View file @
a4875fc0
...
...
@@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
assert
"BNB_CUDA_VERSION"
in
caplog
.
text
# did we get the warning?
def
test_get_cuda_bnb_library_path_override_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
,
caplog
):
monkeypatch
.
setenv
(
"BNB_CUDA_VERSION"
,
"125"
)
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda125_nocublaslt"
assert
"BNB_CUDA_VERSION"
in
caplog
.
text
# did we get the warning?
def
test_get_cuda_bnb_library_path_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
):
monkeypatch
.
delenv
(
"BNB_CUDA_VERSION"
,
raising
=
False
)
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda111_nocublaslt"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment