OpenDAS / bitsandbytes · Commits

Commit 0d1b3a32
Authored May 27, 2025 by Matthew Douglas
Last minute pre-release changes
Parent: 1d4ea6ac

Showing 6 changed files with 111 additions and 83 deletions (+111 -83)
bitsandbytes/backends/cuda/ops.py    +16 -14
bitsandbytes/cextension.py            +8  -5
bitsandbytes/diagnostics/cuda.py      +1 -22
bitsandbytes/diagnostics/main.py     +83 -39
bitsandbytes/diagnostics/utils.py     +1  -1
bitsandbytes/functional.py            +2  -2
bitsandbytes/backends/cuda/ops.py (+16 -14)

@@ -445,20 +445,22 @@ def _gemv_4bit_impl(
     out: torch.Tensor,
 ) -> None:
     torch._check_is_size(blocksize)
-    torch._check(
-        A.numel() == A.size(-1),
-        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    )
-    torch._check(
-        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    )
-    torch._check(
-        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    )
-    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+
+    # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
+    # torch._check(
+    #     A.numel() == A.size(-1),
+    #     lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
+    # )
+    # torch._check(
+    #     A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+    #     lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    # )
+    # torch._check(
+    #     B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+    #     lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    # )
+    # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
+    # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
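The dropped torch._check calls ran on every _gemv_4bit_impl invocation, a hot path for 4-bit matrix-vector products, so the commit trades per-call validation for speed. Below is a minimal sketch of how such checks could instead be kept behind an opt-in debug flag; the flag and helper names are hypothetical illustrations, not bitsandbytes API.

import torch

# Hypothetical pattern (not bitsandbytes code): keep input validation for a hot
# kernel wrapper, but only pay for it when explicitly enabled while debugging.
DEBUG_CHECKS = False  # flip to True to re-enable dtype assertions

def _validate_gemv_4bit_inputs(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor) -> None:
    # Mirrors the checks that this commit comments out.
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

def gemv_4bit_wrapper(A, absmax, code):
    if DEBUG_CHECKS:  # near-zero cost when disabled: a single branch per call
        _validate_gemv_4bit_inputs(A, absmax, code)
    ...  # dispatch to the CUDA kernel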
bitsandbytes/cextension.py (+8 -5)

 import ctypes as ct
+import functools
 import logging
 import os
 from pathlib import Path

@@ -29,10 +30,8 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
         library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
         logger.warning(
             f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
-            "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
-            "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
-            "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
-            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
+            "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
+            "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n",
         )
 
     return PACKAGE_DIR / library_name

@@ -45,10 +44,14 @@ class BNBNativeLibrary:
     def __init__(self, lib: ct.CDLL):
         self._lib = lib
 
+    @functools.cache  # noqa: B019
     def __getattr__(self, name):
+        fn = getattr(self._lib, name, None)
+        if fn is not None:
+            return fn
+
         def throw_on_call(*args, **kwargs):
-            if hasattr(self._lib, name):
-                return getattr(self._lib, name)(*args, **kwargs)
             raise RuntimeError(
                 f"Method '{name}' not available in CPU-only version of bitsandbytes.\n"
                 "Reinstall with GPU support or use CUDA-enabled hardware."
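Two things change here: the BNB_CUDA_VERSION warning text is shortened, and BNBNativeLibrary.__getattr__ now resolves a symbol once, caches the lookup per name via functools.cache, and returns a stub that raises only when a missing symbol is actually called. A standalone sketch of that pattern follows (my proxy class, not the bitsandbytes one), assuming a long-lived ctypes CDLL handle:

import ctypes as ct
import functools

class _NativeLibProxy:
    """Sketch of the pattern used in this commit: resolve symbols from a ctypes
    CDLL lazily, memoize the lookup per attribute name, and hand back a stub for
    missing symbols that only raises when it is actually invoked."""

    def __init__(self, lib: ct.CDLL):
        self._lib = lib

    @functools.cache  # noqa: B019 -- caching on self is fine for a long-lived singleton
    def __getattr__(self, name):
        fn = getattr(self._lib, name, None)
        if fn is not None:
            return fn  # real symbol: cached, so repeated lookups skip ctypes entirely

        def throw_on_call(*args, **kwargs):
            raise RuntimeError(f"Symbol '{name}' is not available in this build.")

        return throw_on_call

Because the decorated __getattr__ memoizes per (instance, name), both the ctypes lookup and the construction of the error-raising stub happen at most once per symbol.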
bitsandbytes/diagnostics/cuda.py (+1 -22)

@@ -6,7 +6,6 @@ from pathlib import Path
 import torch
 
 from bitsandbytes.cextension import get_cuda_bnb_library_path
-from bitsandbytes.consts import NONPYTORCH_DOC_URL
 from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented

@@ -115,25 +114,9 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
         print_dedented(
             f"""
             Library not found: {binary_path}. Maybe you need to compile it from source?
-            If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
-            for example, `make CUDA_VERSION=113`.
-            The CUDA version for the compile might depend on your conda install, if using conda.
-            Inspect CUDA version via `conda list | grep cuda`.
             """,
         )
 
-    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
-    if cuda_major < 11:
-        print_dedented(
-            """
-            WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
-            You will be only to use 8-bit optimizers and quantization routines!
-            """,
-        )
-
-    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
-
     # 7.5 is the minimum CC for int8 tensor cores
     if not cuda_specs.has_imma:
         print_dedented(

@@ -144,10 +127,6 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )
 
-    # TODO:
-    # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
-    # (2) Multiple CUDA versions installed
-
 
 def print_cuda_runtime_diagnostics() -> None:
     cudart_paths = list(find_cudart_libraries())
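The retained comment notes that compute capability 7.5 is the minimum for int8 tensor cores (IMMA). A rough sketch of the kind of check that sits behind cuda_specs.has_imma; the helper name here is mine, and the real implementation lives in bitsandbytes.cuda_specs:

import torch

# Illustrative only: int8 tensor cores (IMMA) require compute capability >= 7.5.
def device_has_imma(device: int = 0) -> bool:
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) >= (7, 5)

if torch.cuda.is_available():
    print(f"int8 tensor core support on device 0: {device_has_imma(0)}")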
bitsandbytes/diagnostics/main.py (+83 -39)

+import importlib
+import platform
 import sys
 import traceback
 
 import torch
 
+from bitsandbytes import __version__ as bnb_version
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
 from bitsandbytes.cuda_specs import get_cuda_specs
 from bitsandbytes.diagnostics.cuda import (
     print_cuda_diagnostics,
-    print_cuda_runtime_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
 
+_RELATED_PACKAGES = [
+    "accelerate",
+    "diffusers",
+    "numpy",
+    "pip",
+    "peft",
+    "safetensors",
+    "transformers",
+    "triton",
+    "trl",
+]
+
 
 def sanity_check():
     from bitsandbytes.optim import Adam

@@ -27,30 +41,59 @@ def sanity_check():
     assert p1 != p2
 
 
+def get_package_version(name: str) -> str:
+    try:
+        version = importlib.metadata.version(name)
+    except importlib.metadata.PackageNotFoundError:
+        version = "not found"
+    return version
+
+
+def show_environment():
+    """Simple utility to print out environment information."""
+    print(f"Platform: {platform.platform()}")
+    if platform.system() == "Linux":
+        print(f"  libc: {'-'.join(platform.libc_ver())}")
+
+    print(f"Python: {platform.python_version()}")
+
+    print(f"PyTorch: {torch.__version__}")
+    print(f"  CUDA: {torch.version.cuda or 'N/A'}")
+    print(f"  HIP: {torch.version.hip or 'N/A'}")
+    print(f"  XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}")
+
+    print("Related packages:")
+    for pkg in _RELATED_PACKAGES:
+        version = get_package_version(pkg)
+        print(f"  {pkg}: {version}")
+
+
 def main():
-    print_header("")
-    print_header("BUG REPORT INFORMATION")
+    print_header(f"bitsandbytes v{bnb_version}")
+    show_environment()
     print_header("")
 
-    print_header("OTHER")
     cuda_specs = get_cuda_specs()
-    print("CUDA specs:", cuda_specs)
-    if not torch.cuda.is_available():
-        print("Torch says CUDA is not available. Possible reasons:")
-        print("1. CUDA driver not installed")
-        print("2. CUDA not installed")
-        print("3. You have multiple conflicting CUDA libraries")
+
     if cuda_specs:
         print_cuda_diagnostics(cuda_specs)
-    print_cuda_runtime_diagnostics()
-    print_header("")
-    print_header("DEBUG INFO END")
-    print_header("")
-    print("Checking that the library is importable and CUDA is callable...")
-    try:
-        sanity_check()
-        print("SUCCESS!")
-        print("Installation was successful!")
-        return
-    except RuntimeError as e:
-        if "not available in CPU-only" in str(e):
+
+    # TODO: There's a lot of noise in this; needs improvement.
+    # print_cuda_runtime_diagnostics()
+
+    if not torch.cuda.is_available():
+        print("PyTorch says CUDA is not available. Possible reasons:")
+        print("1. CUDA driver not installed")
+        print("2. Using a CPU-only PyTorch build")
+        print("3. No GPU detected")
+    else:
+        print("Checking that the library is importable and CUDA is callable...")
+
+        try:
+            sanity_check()
+            print("SUCCESS!")
+            return
+        except RuntimeError as e:
+            if "not available in CPU-only" in str(e):

@@ -63,6 +106,7 @@ def main():
             raise e
     except Exception:
         traceback.print_exc()
         print_dedented(
             f"""
             Above we output some debug information.
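A brief usage note, assuming the package's module entry point is unchanged by this commit: the environment report from show_environment and the CUDA diagnostics are normally reached with `python -m bitsandbytes`, and the same routine can be invoked directly from Python.

# Hedged usage sketch: run the diagnostics reworked in this commit programmatically.
from bitsandbytes.diagnostics.main import main

main()  # prints the environment report, CUDA diagnostics, and runs sanity_check()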
bitsandbytes/diagnostics/utils.py (+1 -1)

@@ -3,7 +3,7 @@ import textwrap
 HEADER_WIDTH = 60
 
 
-def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None:
+def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "=") -> None:
     txt = f" {txt} " if txt else ""
     print(txt.center(width, filler))
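The only change is the default filler character for print_header, from "+" to "=". For reference, the centering relies on str.center; a quick illustration:

# str.center pads the text to the given width with the filler character.
print(" OTHER ".center(60, "="))  # "=" on both sides of " OTHER ", 60 chars wide
print("".center(60, "="))         # an empty title becomes a full 60-character divider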
bitsandbytes/functional.py (+2 -2)

@@ -851,8 +851,8 @@ def dequantize_blockwise(
         torch.ops.bitsandbytes.dequantize_blockwise.out(
             A,
             absmax,
-            code.to(A.device),
-            blocksize,
+            quant_state.code.to(A.device),
+            quant_state.blocksize,
             quant_state.dtype,
             out=out,
         )
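The .out call now reads code and blocksize from the QuantState rather than the local variables, so the fused op always sees the state that produced the tensor. A hedged round-trip sketch of this code path using the public functional API (argument order and defaults assumed from the bitsandbytes docs; requires a CUDA device):

import torch
import bitsandbytes.functional as F

# Quantize a tensor blockwise, then dequantize using the returned quant_state,
# whose `code` and `blocksize` fields are the ones this commit forwards to the op.
A = torch.randn(4096, dtype=torch.float32, device="cuda")
quantized, quant_state = F.quantize_blockwise(A)
restored = F.dequantize_blockwise(quantized, quant_state)
print(torch.allclose(A, restored, atol=1e-2))  # 8-bit blockwise error is small but nonzero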