OpenDAS / bitsandbytes

Commit 58b09ee1, authored Feb 21, 2023 by Max Ryabinin
Parent: 0f5c3948

[WIP] Implement proper serialization of Linear8bitLt
Showing 2 changed files with 81 additions and 3 deletions:

    bitsandbytes/nn/modules.py   +28  -0
    tests/test_linear8bitlt.py   +53  -3
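Background for the diff below (not part of the commit page): in the LLM.int8() scheme, CB is the quantized int8 weight matrix and SCB holds the per-row scaling constants, so the original weight is approximately CB * SCB / 127. A minimal sketch of that assumed relationship, using plain row-wise absmax quantization rather than the bitsandbytes API:

import torch

W = torch.randn(4, 8)
SCB = W.abs().amax(dim=1)                          # per-row scale (absmax)
CB = torch.round(W * 127.0 / SCB.unsqueeze(1)).to(torch.int8)

# dequantize: each int8 row is rescaled by its own SCB entry
W_approx = CB.float() * SCB.unsqueeze(1) / 127.0
assert torch.allclose(W, W_approx, atol=0.05)      # rounding error is bounded per row

This is why the save hook below only needs to store SCB as extra data: the int8 CB matrix is already serialized as weight.data.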
bitsandbytes/nn/modules.py

@@ -224,6 +224,34 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        weight_name = "SCB"
+
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, weight_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, weight_name)
+
+        key_name = prefix + f"{weight_name}"
+        if param_from_weight is not None:
+            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+        elif not self.state.has_fp16_weights and param_from_state is not None:
+            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+        # iterate over a copy: removing from unexpected_keys while iterating it would skip entries
+        for key in list(unexpected_keys):
+            input_name = key[len(prefix):]
+            if input_name == "SCB":
+                input_param = state_dict[key]
+                self.weight.SCB.copy_(input_param)
+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
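Taken together, the two hooks let a quantized Linear8bitLt round-trip through a plain PyTorch state dict. The following sketch (not part of the commit) mirrors the new test below; the shapes, threshold, and file name are illustrative, and a CUDA device is assumed:

import torch
import bitsandbytes as bnb

# build and quantize a layer; .cuda() converts the weight to int8 and
# attaches the SCB scales to it (case 1 in _save_to_state_dict)
layer = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

# the new _save_to_state_dict hook writes an extra "SCB" entry next to
# the int8 weight, which already carries CB in weight.data
torch.save(layer.state_dict(), "linear8bit.pt")

# quantize the fresh module first so weight.SCB exists, then load;
# _load_from_state_dict consumes the "SCB" key from unexpected_keys,
# so strict=True passes
restored = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
restored.load_state_dict(torch.load("linear8bit.pt"), strict=True)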
tests/test_linear8bitlt.py

-import bitsandbytes as bnb
+from copy import deepcopy
+
 import pytest
 import torch
+
+import bitsandbytes as bnb
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py

@@ -26,6 +29,7 @@ def test_layout_exact_match():
     assert restored_x.is_contiguous()
     assert torch.all(torch.eq(restored_x, x))
 
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)

@@ -43,7 +47,7 @@ def test_linear_no_igemmlt():
         linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
     linear = linear.half().cuda()
     x_ref = x.clone().cuda().requires_grad_(True)

@@ -59,3 +63,49 @@ def test_linear_no_igemmlt():
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("has_fp16_weights", [False, True])
+def test_linear_serialization(has_fp16_weights):
+    linear = torch.nn.Linear(16, 32)
+    x = torch.randn(3, 16, dtype=torch.half)
+
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom = linear_custom.cuda()
+
+    x_first = x.clone().cuda().requires_grad_(True)
+    fx_first = linear_custom(x_first).float()
+    grad_proj = torch.randn_like(fx_first)
+    (fx_first * grad_proj).mean().backward()
+
+    state_dict = deepcopy(linear_custom.state_dict())
+
+    new_linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    new_linear_custom.state.force_no_igemmlt = True
+    new_linear_custom = new_linear_custom.cuda()
+    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    x_second = x.clone().cuda().requires_grad_(True)
+    fx_second = new_linear_custom(x_second).float()
+    (fx_second * grad_proj).mean().backward()
+
+    assert torch.allclose(fx_first, fx_second, atol=1e-5)
+    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
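A quick way to see what the hooks change in practice: with has_fp16_weights=False, the saved dict gains an "SCB" entry while the weight tensor itself is already int8; with has_fp16_weights=True, the weight stays fp16 and no "SCB" entry is written. A hypothetical spot-check (not part of the commit, CUDA assumed):

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
sd = layer.state_dict()

assert "SCB" in sd                       # written by the new _save_to_state_dict hook
assert sd["weight"].dtype == torch.int8  # CB already lives in weight.data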