OpenDAS / bitsandbytes — commit ed6f3eb1 (unverified)

Authored Apr 11, 2023 by Tim Dettmers, committed by GitHub on Apr 11, 2023.
Parents: b0ec20c3, dcecbb26

Merge pull request #159 from TimDettmers/serialize_8bit

Implement proper serialization of Linear8bitLt

3 changed files with 144 additions and 11 deletions (+144, -11)
Changed files:
  bitsandbytes/autograd/_functions.py   +10  -8
  bitsandbytes/nn/modules.py            +49  -0
  tests/test_linear8bitlt.py            +85  -3
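For orientation, here is a minimal sketch of the serialization round trip that this merge enables, assuming a CUDA-capable environment. The module name and keyword arguments come from the diffs below; the layer sizes and file name are illustrative only, and the requirement to call .cuda() before load_state_dict follows from the RuntimeError added in modules.py:

import torch
import bitsandbytes as bnb

# Build an 8-bit linear layer; quantization happens when the module is moved to CUDA.
linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()

# With this merge, the state dict carries the int8 weight plus the SCB quantization scales.
torch.save(linear.state_dict(), "linear_8bit.pth")

# To restore, create a fresh module, move it to CUDA first, then load the checkpoint.
restored = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0)
restored = restored.cuda()  # loading a quantized checkpoint into a CPU module raises RuntimeError
restored.load_state_dict(torch.load("linear_8bit.pth"))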
bitsandbytes/autograd/_functions.py  (+10, -8)  View file @ ed6f3eb1

@@ -234,7 +234,7 @@ def supports_igemmlt(device: torch.device) -> bool:
 @dataclass
 class MatmulLtState:
-    tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
     CxB = None

@@ -274,6 +274,15 @@ class MatmulLtState:
         ), f"please find this assert and manually enter tile size for {self.formatB}"
         return (8, 32) if self.formatB == "col_turing" else (32, 32)

+    @property
+    def tile_indices(self):
+        if self._tile_indices is None:
+            device = self.CxB.device
+            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
+            with torch.no_grad():
+                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+        return self._tile_indices
+

 class MatMul8bitLt(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs

@@ -466,13 +475,6 @@ class MatMul8bitLt(torch.autograd.Function):
             CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
             grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
         elif state.CxB is not None:
-            if state.tile_indices is None:
-                order, tile_size = state.formatB, state.get_tile_size()
-                transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                with torch.no_grad():
-                    state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
             CB = (
                 undo_layout(state.CxB, state.tile_indices)
                 .to(ctx.dtype_A)
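The new tile_indices property caches the output of get_inverse_transform_indices so that undo_layout can map the GPU-specific col_turing/col_ampere weight layout back to row-major order on demand, instead of recomputing the indices inside backward. A generic, self-contained sketch of the underlying idea in plain PyTorch, not using the bitsandbytes kernels; the permute_rows helper and tensor sizes here are illustrative assumptions, not part of the library:

import torch

# A stand-in for a layout transform: any fixed permutation of the elements.
perm = torch.randperm(16)

def permute_rows(x: torch.Tensor) -> torch.Tensor:
    # hypothetical "transform": reorder a flattened tile by a fixed permutation
    return x.flatten()[perm].reshape(x.shape)

# Recover the inverse indices by transforming an index range and inverting it with argsort,
# the same trick the cached tile_indices property relies on.
eye = torch.arange(16).reshape(4, 4)
forward_indices = permute_rows(eye).flatten()
inverse_indices = torch.argsort(forward_indices)

# Undoing the layout is then a single gather with the cached inverse indices.
x = torch.randn(4, 4)
x_transformed = permute_rows(x)
x_restored = x_transformed.flatten()[inverse_indices].reshape(4, 4)
assert torch.equal(x_restored, x)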
bitsandbytes/nn/modules.py  (+49, -0)  View file @ ed6f3eb1

@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn

 import bitsandbytes as bnb
+import bitsandbytes.functional
+from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
 from bitsandbytes.optim import GlobalOptimManager

 T = TypeVar("T", bound="torch.nn.Module")

@@ -224,6 +226,53 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
+            # reorder weight layout back from ampere/turing to row
+            reorder_layout = True
+            weight_clone = self.weight.data.clone()
+        else:
+            reorder_layout = False
+
+        try:
+            if reorder_layout:
+                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+
+            super()._save_to_state_dict(destination, prefix, keep_vars)
+
+            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+            weight_name = "SCB"
+
+            # case 1: .cuda was called, SCB is in self.weight
+            param_from_weight = getattr(self.weight, weight_name)
+            # case 2: self.init_8bit_state was called, SCB is in self.state
+            param_from_state = getattr(self.state, weight_name)
+
+            key_name = prefix + f"{weight_name}"
+            if param_from_weight is not None:
+                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        finally:
+            if reorder_layout:
+                self.weight.data = weight_clone
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+        for key in unexpected_keys:
+            input_name = key[len(prefix):]
+            if input_name == "SCB":
+                if self.weight.SCB is None:
+                    # buffers not yet initialized, can't call them directly without
+                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
+                                       "not supported. Please call module.cuda() before module.load_state_dict()")
+
+                input_param = state_dict[key]
+                self.weight.SCB.copy_(input_param)
+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
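The two overrides above use the standard torch.nn.Module state-dict hooks: _save_to_state_dict writes one extra tensor (the SCB quantization scales) next to the regular parameters, and _load_from_state_dict picks that key back out of unexpected_keys so that load_state_dict(strict=True) still succeeds. A minimal, self-contained sketch of that pattern for a generic module; the class and the extra_scale tensor are hypothetical stand-ins, not part of the Linear8bitLt code:

import torch
from torch import nn

class ModuleWithExtraTensor(nn.Module):
    # illustrative module: "extra_scale" stands in for non-parameter state such as SCB
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.extra_scale = torch.ones(4)  # plain tensor, not registered as a buffer

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # add the extra tensor under its own key, next to the normal parameters
        destination[prefix + "extra_scale"] = self.extra_scale if keep_vars else self.extra_scale.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        # the parent class does not know about extra_scale, so it lands in unexpected_keys;
        # consume it there so load_state_dict(strict=True) does not fail
        key = prefix + "extra_scale"
        if key in unexpected_keys:
            self.extra_scale.copy_(state_dict[key])
            unexpected_keys.remove(key)

sd = ModuleWithExtraTensor().state_dict()
assert "extra_scale" in sd
ModuleWithExtraTensor().load_state_dict(sd, strict=True)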
tests/test_linear8bitlt.py  (+85, -3)  View file @ ed6f3eb1

-import bitsandbytes as bnb
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
+
 import pytest
 import torch

+import bitsandbytes as bnb
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt

 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py

@@ -26,6 +32,7 @@ def test_layout_exact_match():
     assert restored_x.is_contiguous()
     assert torch.all(torch.eq(restored_x, x))

+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)

@@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
         linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
     linear = linear.half().cuda()
     x_ref = x.clone().cuda().requires_grad_(True)

@@ -59,3 +66,78 @@ def test_linear_no_igemmlt():
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
+                         list(product([False, True], [False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
+
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    if force_no_igemmlt:
+        linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
+    linear_custom.bias = linear.bias
+    linear_custom = linear_custom.cuda()
+
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    x_first = x.clone().cuda().requires_grad_(True)
+    fx_first = linear_custom(x_first).float()
+    grad_proj = torch.randn_like(fx_first)
+    (fx_first * grad_proj).mean().backward()
+
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
+
+    new_linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    if force_no_igemmlt:
+        new_linear_custom.state.force_no_igemmlt = True
+
+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+    new_linear_custom = new_linear_custom.cuda()
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+    x_second = x.clone().cuda().requires_grad_(True)
+    fx_second = new_linear_custom(x_second).float()
+    (fx_second * grad_proj).mean().backward()
+
+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)