OpenDAS / bitsandbytes · Commits

Commit 9d60b3c5
Authored Aug 17, 2022 by Tim Dettmers

Fixed bug in Linear8bitLt, when the bias is None.

Parent: b00cc913
Showing 3 changed files with 27 additions and 4 deletions (+27, −4)
bitsandbytes/nn/modules.py  (+3, −3)
setup.py  (+1, −1)
tests/test_modules.py  (+23, −0)
bitsandbytes/nn/modules.py

@@ -248,10 +248,10 @@ class Linear8bitLt(nn.Linear):
         if self.weight.CB is not None:
             self.init_8bit_state()
-        if self.bias.dtype != torch.float16:
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != torch.float16:
             self.bias.data = self.bias.data.half()
         # assert not self.state.has_fp16_weights
         # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
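For context, a minimal sketch of the failure this change guards against; it is illustrative only, mirrors the new test added below, and assumes a CUDA-enabled bitsandbytes install:

import torch
import bitsandbytes as bnb

# A Linear8bitLt layer created with bias=False has self.bias == None.
# Before this commit, forward() evaluated self.bias.dtype unconditionally,
# raising AttributeError: 'NoneType' object has no attribute 'dtype'.
layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()

x = torch.randn(16, 8, 32, device="cuda").half()
out = layer(x)  # succeeds once the `self.bias is not None` check is in place
assert layer.bias is None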
setup.py

@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.32.0",
+    version=f"0.32.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
...
tests/test_modules.py
View file @
9d60b3c5
...
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
...
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc1
.
weight
.
device
.
type
==
"cuda"
assert
mlp
.
fc1
.
weight
.
device
.
type
==
"cuda"
assert
mlp
.
fc2
.
weight
.
device
.
type
==
"cuda"
assert
mlp
.
fc2
.
weight
.
device
.
type
==
"cuda"
def
test_linear8bitlt_fp32_bias
():
# casts model to fp16 -> int8 automatically
l1
=
bnb
.
nn
.
Linear8bitLt
(
32
,
64
,
has_fp16_weights
=
False
).
cuda
()
assert
l1
.
weight
.
dtype
==
torch
.
int8
assert
l1
.
bias
.
dtype
==
torch
.
float32
for
i
in
range
(
100
):
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
# casts bias to fp32
o1
=
l1
(
b1
)
assert
l1
.
bias
.
dtype
==
torch
.
float16
# casts model to fp16 -> int8 automatically
l1
=
bnb
.
nn
.
Linear8bitLt
(
32
,
64
,
has_fp16_weights
=
False
,
bias
=
False
).
cuda
()
assert
l1
.
weight
.
dtype
==
torch
.
int8
assert
l1
.
bias
is
None
for
i
in
range
(
100
):
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
o1
=
l1
(
b1
)
assert
l1
.
bias
is
None
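To exercise just the new case locally, something like pytest tests/test_modules.py::test_linear8bitlt_fp32_bias should work, assuming pytest and a CUDA device are available.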