OpenDAS / bitsandbytes · Commit c0c352b3
"src/vscode:/vscode.git/clone" did not exist on "73b59f5203b5df71175dfd71f613b9bd380b4531"
Authored Feb 05, 2023 by Tim Dettmers
Added bias test for LinearFP4 and basic test.
Parent: c361f842
Showing 3 changed files with 16 additions and 35 deletions:
bitsandbytes/nn/__init__.py  (+1 -1)
bitsandbytes/nn/modules.py   (+3 -3)
tests/test_modules.py        (+12 -31)
bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4
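With this re-export in place, the new layer is importable from the package's nn namespace alongside the existing 8-bit layer. A minimal sketch of the resulting import path:

from bitsandbytes.nn import Linear8bitLt, LinearFP4  # LinearFP4 is newly exported by this commit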
bitsandbytes/nn/modules.py

@@ -188,9 +188,9 @@ class LinearFP4(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, 'state', None) is None:
-            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+        if getattr(self.weight, 'quant_state', None) is None:
+            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)

         return out
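The functional change here is that LinearFP4.forward now reads the quantization state from weight.quant_state (instead of weight.state) and dispatches to bnb.matmul_fp4 on the transposed weight. A minimal usage sketch of the layer as it stands after this commit (a sketch only, assuming a CUDA device; per the warning message above, moving the layer to the GPU is what populates weight.quant_state):

import torch
import bitsandbytes as bnb

fp4 = bnb.nn.LinearFP4(32, 64).cuda()        # .cuda()/.to(device) quantizes the weight and sets weight.quant_state
x = torch.randn(16, 8, 32, device="cuda").half()
y = fp4(x)                                   # forward casts the bias to x.dtype, then calls bnb.matmul_fp4(x, weight.t(), ...)
assert y.shape == (16, 8, 64)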
tests/test_modules.py

@@ -330,12 +330,8 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -376,21 +372,10 @@ def test_linear8bitlt_accumulated_gradient():
     torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-
-
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = (bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8

     l1.eval()
@@ -446,13 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8

-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = (MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -504,10 +483,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert (idx == 0).sum().item() <= b1.numel() * 0.005


-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]

     assert l1.bias.dtype == torch.float32

     for i in range(100):
@@ -517,11 +497,12 @@ def test_linear8bitlt_fp32_bias():
     assert l1.bias.dtype == torch.float16

     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None
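The net effect in the test suite: the old Linear8bitLt-only bias test becomes test_linear_kbit_fp32_bias, parametrized over both k-bit layer types (a lambda is used for Linear8bitLt so that has_fp16_weights=False can be fixed, while bnb.nn.LinearFP4 is passed directly). Outside of pytest, the behaviour it checks can be sketched roughly as follows (a sketch only, assuming a CUDA device and the bitsandbytes code at this commit):

import torch
import bitsandbytes as bnb

# The same two factories the parametrized test uses.
factories = {
    'Int8Lt': lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False),
    'FP4': bnb.nn.LinearFP4,
}

for name, module in factories.items():
    l1 = module(32, 64).cuda()                        # weights are quantized when the layer moves to the GPU
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32             # bias stays fp32 until the first fp16 input arrives

    b1 = torch.randn(16, 8, 32, device="cuda").half()
    l1(b1)                                            # forward casts the bias to the input dtype
    assert l1.bias.dtype == torch.float16

    l1 = module(32, 64, bias=False).cuda()            # bias=False: nothing to cast
    assert l1.bias is None
    l1(b1)
    assert l1.bias is None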