OpenDAS / bitsandbytes
Commit 58b09ee1, authored Feb 21, 2023 by Max Ryabinin
[WIP] Implement proper serialization of Linear8bitLt
Parent: 0f5c3948
Showing 2 changed files with 81 additions and 3 deletions:

bitsandbytes/nn/modules.py   +28 -0
tests/test_linear8bitlt.py   +53 -3
bitsandbytes/nn/modules.py

@@ -224,6 +224,34 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        weight_name = "SCB"
+
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, weight_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, weight_name)
+
+        key_name = prefix + f"{weight_name}"
+        if param_from_weight is not None:
+            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+        elif not self.state.has_fp16_weights and param_from_state is not None:
+            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+        for key in unexpected_keys:
+            input_name = key[len(prefix):]
+            if input_name == "SCB":
+                input_param = state_dict[key]
+                self.weight.SCB.copy_(input_param)
+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
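For context, a minimal round-trip sketch of what the new hooks enable. It is an illustration rather than part of the commit: it assumes a GPU supported by bitsandbytes' int8 kernels, and the 16x32 shape and the file name are arbitrary.

    import torch
    import bitsandbytes as bnb

    # With has_fp16_weights=False, .cuda() quantizes the weight: the int8 matrix lands in
    # weight.data (CB) and the row-wise scales in weight.SCB ("case 1" in the hook above).
    layer = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

    state = layer.state_dict()
    print(sorted(state.keys()))  # expected to contain 'SCB' next to 'bias' and 'weight'

    torch.save(state, "linear8bit.pt")  # hypothetical file name

    # A fresh layer, moved to the GPU first, recovers the scales via _load_from_state_dict.
    restored = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
    restored.load_state_dict(torch.load("linear8bit.pt"), strict=True)

    x = torch.randn(3, 16, dtype=torch.half, device="cuda")
    assert torch.allclose(layer(x), restored(x), atol=1e-5)  # mirrors the check in the new test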
tests/test_linear8bitlt.py

-import bitsandbytes as bnb
+from copy import deepcopy
+
 import pytest
 import torch
-from bitsandbytes import functional as F
+
+import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py

@@ -26,6 +29,7 @@ def test_layout_exact_match():
     assert restored_x.is_contiguous()
     assert torch.all(torch.eq(restored_x, x))
 
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)

@@ -43,7 +47,7 @@ def test_linear_no_igemmlt():
         linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
     linear = linear.half().cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)

@@ -59,3 +63,49 @@ def test_linear_no_igemmlt():
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("has_fp16_weights", [False, True])
+def test_linear_serialization(has_fp16_weights):
+    linear = torch.nn.Linear(16, 32)
+    x = torch.randn(3, 16, dtype=torch.half)
+
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom = linear_custom.cuda()
+
+    x_first = x.clone().cuda().requires_grad_(True)
+    fx_first = linear_custom(x_first).float()
+    grad_proj = torch.randn_like(fx_first)
+    (fx_first * grad_proj).mean().backward()
+
+    state_dict = deepcopy(linear_custom.state_dict())
+
+    new_linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    linear_custom.state.force_no_igemmlt = True
+    new_linear_custom = new_linear_custom.cuda()
+    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    x_second = x.clone().cuda().requires_grad_(True)
+    fx_second = new_linear_custom(x_second).float()
+    (fx_second * grad_proj).mean().backward()
+
+    assert torch.allclose(fx_first, fx_second, atol=1e-5)
+    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
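As a companion sketch (again illustrative, not part of the commit, and under the same GPU assumption), the two cases distinguished by _save_to_state_dict can be observed directly: straight after .cuda() the scales sit on the weight, while after a forward pass init_8bit_state has handed them over to layer.state, and the exported state dict keeps its "SCB" entry in both situations.

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

    # Case 1: .cuda() was called, so SCB is read from the weight itself.
    assert layer.weight.SCB is not None
    assert "SCB" in layer.state_dict()

    # Case 2: the first forward pass calls init_8bit_state, moving CB/SCB into layer.state;
    # the elif branch of the save hook then exports layer.state.SCB instead.
    x = torch.randn(3, 16, dtype=torch.half, device="cuda")
    layer(x)
    assert layer.state.SCB is not None
    assert "SCB" in layer.state_dict()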