OpenDAS / bitsandbytes · Commit a1c0844b

adding whole Linear8bitLt/Linear4bit module save/load serialization (#1099)

Unverified commit a1c0844b, authored Mar 05, 2024 by rdyro, committed by GitHub on Mar 05, 2024. Parent: f9eba9c8
Showing 3 changed files, with 62 additions and 3 deletions (+62, -3):

- bitsandbytes/nn/modules.py (+3, -1)
- tests/test_linear4bit.py (+26, -1)
- tests/test_linear8bitlt.py (+33, -1)
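In effect, this commit makes whole quantized modules picklable: a Linear8bitLt or Linear4bit layer can now be round-tripped through torch.save / torch.load directly, not only via its state_dict. A minimal usage sketch (an illustration, not code from the commit; assumes a CUDA device, and note that PyTorch 2.6+ additionally requires weights_only=False to unpickle whole modules):

```python
from io import BytesIO

import torch
import bitsandbytes as bnb

# Quantized 8-bit linear layer; the first forward pass quantizes the weights.
module = bnb.nn.Linear8bitLt(32, 96, has_fp16_weights=False).cuda()
module(torch.randn(3, 32, dtype=torch.half, device="cuda"))

# Round-trip the whole module, including its quantization state.
buffer = BytesIO()
torch.save(module, buffer)
buffer.seek(0)
restored = torch.load(buffer)  # on PyTorch >= 2.6: torch.load(buffer, weights_only=False)
```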
bitsandbytes/nn/modules.py

```diff
@@ -449,7 +449,9 @@ class Int8Params(torch.nn.Parameter):
         cls.SCB = None
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj.CB, obj.SCB = cls.CB, cls.SCB
+        return obj
 
     def cuda(self, device):
         if self.has_fp16_weights:
```
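This one-line-to-three change is the mechanism behind whole-module pickling: it puts `CB` and `SCB` on the parameter *instance* (its `__dict__`) rather than only on the class, so the quantization state travels with each `Int8Params` object and can be captured and restored when the enclosing module is pickled. A toy illustration of why the instance copy matters (the `TaggedParam`/`tag` names are hypothetical, not part of the commit):

```python
import torch

class ClassAttrParam(torch.nn.Parameter):
    """Anti-pattern: state stored on the class is shared and clobbered."""
    def __new__(cls, data, requires_grad=False, tag=None):
        cls.tag = tag  # class attribute: every instance sees the latest value
        return torch.Tensor._make_subclass(cls, data, requires_grad)

class InstanceAttrParam(torch.nn.Parameter):
    """The fixed pattern: copy state onto the instance before returning."""
    def __new__(cls, data, requires_grad=False, tag=None):
        cls.tag = tag
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        obj.tag = cls.tag  # instance attribute: lives in obj.__dict__
        return obj

a = ClassAttrParam(torch.zeros(1), tag="A")
b = ClassAttrParam(torch.zeros(1), tag="B")
print(a.tag, b.tag)  # B B: a's state was clobbered by constructing b

c = InstanceAttrParam(torch.zeros(1), tag="C")
d = InstanceAttrParam(torch.zeros(1), tag="D")
print(c.tag, d.tag)  # C D: state sticks to each instance
```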
tests/test_linear4bit.py

```diff
 import copy
+from io import BytesIO
 import os
 import pickle
 from tempfile import TemporaryDirectory

@@ -16,12 +17,24 @@ storage = {
     "float32": torch.float32,
 }
 
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"

@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         assert a.dtype == b.dtype
         assert torch.equal(a, b)
 
+    if save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+
     # Forward test
     x = torch.rand(42, layer_shape[0], device=device)
     a = linear_q(x)

@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert torch.equal(a, b)
     assert torch.equal(a, c)
 
+    if not save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+    linear_q3 = torch_load_from_buffer(bytes_4bit)
+
     # Test moving to CPU and back to GPU
     linear_q2.to("cpu")
     linear_q2.to(device)

@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert c.device == d.device
     assert torch.equal(c, d)
 
+    d = linear_q3(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)
+
     # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
     with TemporaryDirectory() as tmpdir:
         state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
```
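The new `save_before_forward` parameter exercises whole-module pickling both before and after the layer has run a forward pass, and the `linear_q3` assertions require the loaded copy to produce bit-identical outputs. Roughly, the round trip being asserted (a sketch assuming a CUDA device; shapes follow the test's `layer_shape == (300, 400)`):

```python
from io import BytesIO

import torch
import bitsandbytes as bnb

linear_q = bnb.nn.Linear4bit(300, 400, quant_type="nf4").cuda()  # .cuda() quantizes the weight
x = torch.rand(42, 300, device="cuda")

buffer = BytesIO()
torch.save(linear_q, buffer)   # save the whole quantized module
buffer.seek(0)
linear_q3 = torch.load(buffer)  # on PyTorch >= 2.6: weights_only=False

assert torch.equal(linear_q(x), linear_q3(x))  # outputs match exactly
```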
tests/test_linear8bitlt.py

```diff
 from contextlib import nullcontext
+from io import BytesIO
 import os
 from tempfile import TemporaryDirectory

@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)

@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
+    if save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)

@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if not serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
+    if not save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")

@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
             new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
+    if load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     new_linear_custom = new_linear_custom.cuda()
 
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
+    if not load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
+    x_third = x.clone().cuda().requires_grad_(True)
+    fx_third = new_linear_custom2(x_third).float()
+    (fx_third * grad_proj).mean().backward()
+
     # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
     if has_fp16_weights or not deserialize_before_cuda:
         assert torch.allclose(fx_first, fx_second, atol=1e-5)
         assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+        assert torch.allclose(fx_first, fx_third, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
\ No newline at end of file
```
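The final output and gradient comparisons are guarded because one parameter combination is expected to fail by design: with `has_fp16_weights=False`, loading an int8 `state_dict` before the module is moved to CUDA raises `RuntimeError` (the `pytest.raises` arm of the `with` statement above), leaving the module's state invalid to compare. In isolation, the conditional-expectation idiom used there looks like this (a self-contained sketch, not code from the commit):

```python
from contextlib import nullcontext

import pytest

@pytest.mark.parametrize("should_raise", [False, True])
def test_conditional_raises(should_raise):
    # Expect a RuntimeError only on the arm where one is anticipated;
    # nullcontext() is a no-op context manager for the passing arm.
    with pytest.raises(RuntimeError) if should_raise else nullcontext():
        if should_raise:
            raise RuntimeError("invalid state")
```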