Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
e9f3605f
Commit
e9f3605f
authored
Jun 06, 2025
by
Matthew Douglas
Browse files
Fix Linear4bit warnings/test for compute dtype
parent
812ef06a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
14 deletions
+6
-14
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+2
-2
tests/test_modules.py
tests/test_modules.py
+4
-12
No files found.
bitsandbytes/nn/modules.py
View file @
e9f3605f
...
...
@@ -455,14 +455,14 @@ class Linear4bit(nn.Linear):
self
.
compute_dtype
=
x
.
dtype
elif
x
.
dtype
==
torch
.
float16
:
# we take the compute dtype passed into the layer
if
self
.
compute_dtype
==
torch
.
float32
and
(
x
.
numel
()
==
x
.
shape
[
-
1
]):
if
self
.
compute_dtype
in
[
None
,
torch
.
float32
]
and
(
x
.
numel
()
==
x
.
shape
[
-
1
]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings
.
warn
(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference."
,
)
warnings
.
filterwarnings
(
"ignore"
,
message
=
".*inference."
)
if
self
.
compute_dtype
==
torch
.
float32
and
(
x
.
numel
()
!=
x
.
shape
[
-
1
]):
if
self
.
compute_dtype
in
[
None
,
torch
.
float32
]
and
(
x
.
numel
()
!=
x
.
shape
[
-
1
]):
warnings
.
warn
(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed."
,
)
...
...
tests/test_modules.py
View file @
e9f3605f
...
...
@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
dim1
=
64
with
pytest
.
warns
(
UserWarning
,
match
=
r
"inference or training"
):
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)]
)
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
)
for
i
in
range
(
10
)])
net
=
net
.
to
(
device
)
inp
=
torch
.
rand
(
10
,
dim1
,
device
=
device
,
dtype
=
torch
.
float16
)
net
(
inp
)
with
pytest
.
warns
(
UserWarning
,
match
=
r
"inference."
):
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)]
)
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
)
for
i
in
range
(
10
)])
net
=
net
.
to
(
device
)
inp
=
torch
.
rand
(
1
,
dim1
,
device
=
device
,
dtype
=
torch
.
float16
)
net
(
inp
)
with
pytest
.
warns
(
UserWarning
)
as
record
:
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)]
)
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
)
for
i
in
range
(
10
)])
net
=
net
.
to
(
device
)
inp
=
torch
.
rand
(
10
,
dim1
,
device
=
device
,
dtype
=
torch
.
float16
)
net
(
inp
)
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)]
)
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
quant_type
=
"nf4"
)
for
i
in
range
(
10
)])
net
=
net
.
to
(
device
)
inp
=
torch
.
rand
(
1
,
dim1
,
device
=
device
,
dtype
=
torch
.
float16
)
net
(
inp
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment