Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
a1c0844b
Unverified
Commit
a1c0844b
authored
Mar 05, 2024
by
rdyro
Committed by
GitHub
Mar 05, 2024
Browse files
adding whole Linear8bitLt/Linear4bit module save/load serialization (#1099)
parent
f9eba9c8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
3 deletions
+62
-3
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+3
-1
tests/test_linear4bit.py
tests/test_linear4bit.py
+26
-1
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+33
-1
No files found.
bitsandbytes/nn/modules.py
View file @
a1c0844b
...
@@ -449,7 +449,9 @@ class Int8Params(torch.nn.Parameter):
...
@@ -449,7 +449,9 @@ class Int8Params(torch.nn.Parameter):
cls
.
SCB
=
None
cls
.
SCB
=
None
if
data
is
None
:
if
data
is
None
:
data
=
torch
.
empty
(
0
)
data
=
torch
.
empty
(
0
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
obj
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
obj
.
CB
,
obj
.
SCB
=
cls
.
CB
,
cls
.
SCB
return
obj
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
if
self
.
has_fp16_weights
:
if
self
.
has_fp16_weights
:
...
...
tests/test_linear4bit.py
View file @
a1c0844b
import
copy
import
copy
from
io
import
BytesIO
import
os
import
os
import
pickle
import
pickle
from
tempfile
import
TemporaryDirectory
from
tempfile
import
TemporaryDirectory
...
@@ -16,12 +17,24 @@ storage = {
...
@@ -16,12 +17,24 @@ storage = {
"float32"
:
torch
.
float32
,
"float32"
:
torch
.
float32
,
}
}
def
torch_save_to_buffer
(
obj
):
buffer
=
BytesIO
()
torch
.
save
(
obj
,
buffer
)
buffer
.
seek
(
0
)
return
buffer
def
torch_load_from_buffer
(
buffer
):
buffer
.
seek
(
0
)
obj
=
torch
.
load
(
buffer
)
buffer
.
seek
(
0
)
return
obj
@
pytest
.
mark
.
parametrize
(
"quant_storage"
,
[
"uint8"
,
"float16"
,
"bfloat16"
,
"float32"
])
@
pytest
.
mark
.
parametrize
(
"quant_storage"
,
[
"uint8"
,
"float16"
,
"bfloat16"
,
"float32"
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
TRUE_FALSE
)
@
pytest
.
mark
.
parametrize
(
"bias"
,
TRUE_FALSE
)
@
pytest
.
mark
.
parametrize
(
"compress_statistics"
,
TRUE_FALSE
)
@
pytest
.
mark
.
parametrize
(
"compress_statistics"
,
TRUE_FALSE
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
,
"fp4"
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
,
"fp4"
])
def
test_linear_serialization
(
quant_type
,
compress_statistics
,
bias
,
quant_storage
):
@
pytest
.
mark
.
parametrize
(
"save_before_forward"
,
TRUE_FALSE
)
def
test_linear_serialization
(
quant_type
,
compress_statistics
,
bias
,
quant_storage
,
save_before_forward
):
original_dtype
=
torch
.
float16
original_dtype
=
torch
.
float16
compute_dtype
=
None
compute_dtype
=
None
device
=
"cuda"
device
=
"cuda"
...
@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert
a
.
dtype
==
b
.
dtype
assert
a
.
dtype
==
b
.
dtype
assert
torch
.
equal
(
a
,
b
)
assert
torch
.
equal
(
a
,
b
)
if
save_before_forward
:
bytes_4bit
=
torch_save_to_buffer
(
linear_q
)
# Forward test
# Forward test
x
=
torch
.
rand
(
42
,
layer_shape
[
0
],
device
=
device
)
x
=
torch
.
rand
(
42
,
layer_shape
[
0
],
device
=
device
)
a
=
linear_q
(
x
)
a
=
linear_q
(
x
)
...
@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert
torch
.
equal
(
a
,
b
)
assert
torch
.
equal
(
a
,
b
)
assert
torch
.
equal
(
a
,
c
)
assert
torch
.
equal
(
a
,
c
)
if
not
save_before_forward
:
bytes_4bit
=
torch_save_to_buffer
(
linear_q
)
linear_q3
=
torch_load_from_buffer
(
bytes_4bit
)
# Test moving to CPU and back to GPU
# Test moving to CPU and back to GPU
linear_q2
.
to
(
"cpu"
)
linear_q2
.
to
(
"cpu"
)
linear_q2
.
to
(
device
)
linear_q2
.
to
(
device
)
...
@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert
c
.
device
==
d
.
device
assert
c
.
device
==
d
.
device
assert
torch
.
equal
(
c
,
d
)
assert
torch
.
equal
(
c
,
d
)
d
=
linear_q3
(
x
)
assert
c
.
dtype
==
d
.
dtype
assert
c
.
device
==
d
.
device
assert
torch
.
equal
(
c
,
d
)
# Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
# Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
with
TemporaryDirectory
()
as
tmpdir
:
with
TemporaryDirectory
()
as
tmpdir
:
state_path_4bit
=
os
.
path
.
join
(
tmpdir
,
"state_4bit.pth"
)
state_path_4bit
=
os
.
path
.
join
(
tmpdir
,
"state_4bit.pth"
)
...
...
tests/test_linear8bitlt.py
View file @
a1c0844b
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
io
import
BytesIO
import
os
import
os
from
tempfile
import
TemporaryDirectory
from
tempfile
import
TemporaryDirectory
...
@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
...
@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
assert
linear_custom
.
state
.
CB
is
not
None
assert
linear_custom
.
state
.
CB
is
not
None
assert
linear_custom
.
state
.
CxB
is
None
assert
linear_custom
.
state
.
CxB
is
None
def
torch_save_to_buffer
(
obj
):
buffer
=
BytesIO
()
torch
.
save
(
obj
,
buffer
)
buffer
.
seek
(
0
)
return
buffer
def
torch_load_from_buffer
(
buffer
):
buffer
.
seek
(
0
)
obj
=
torch
.
load
(
buffer
)
buffer
.
seek
(
0
)
return
obj
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_fp16_weights"
))
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_fp16_weights"
))
@
pytest
.
mark
.
parametrize
(
"serialize_before_forward"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"serialize_before_forward"
))
@
pytest
.
mark
.
parametrize
(
"serialize_before_forward"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"serialize_before_forward"
))
@
pytest
.
mark
.
parametrize
(
"deserialize_before_cuda"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"deserialize_before_cuda"
))
@
pytest
.
mark
.
parametrize
(
"deserialize_before_cuda"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"deserialize_before_cuda"
))
@
pytest
.
mark
.
parametrize
(
"force_no_igemmlt"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"force_no_igemmlt"
))
@
pytest
.
mark
.
parametrize
(
"force_no_igemmlt"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"force_no_igemmlt"
))
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
,
force_no_igemmlt
):
@
pytest
.
mark
.
parametrize
(
"save_before_forward"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"save_before_forward"
))
@
pytest
.
mark
.
parametrize
(
"load_before_cuda"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"load_before_cuda"
))
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
,
force_no_igemmlt
,
save_before_forward
,
load_before_cuda
):
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
...
@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
if
serialize_before_forward
:
if
serialize_before_forward
:
state_dict_8bit
=
linear_custom
.
state_dict
()
state_dict_8bit
=
linear_custom
.
state_dict
()
if
save_before_forward
:
bytes_8bit
=
torch_save_to_buffer
(
linear_custom
)
x_first
=
x
.
clone
().
cuda
().
requires_grad_
(
True
)
x_first
=
x
.
clone
().
cuda
().
requires_grad_
(
True
)
fx_first
=
linear_custom
(
x_first
).
float
()
fx_first
=
linear_custom
(
x_first
).
float
()
grad_proj
=
torch
.
randn_like
(
fx_first
)
grad_proj
=
torch
.
randn_like
(
fx_first
)
...
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
if
not
serialize_before_forward
:
if
not
serialize_before_forward
:
state_dict_8bit
=
linear_custom
.
state_dict
()
state_dict_8bit
=
linear_custom
.
state_dict
()
if
not
save_before_forward
:
bytes_8bit
=
torch_save_to_buffer
(
linear_custom
)
with
TemporaryDirectory
()
as
tmpdir
:
with
TemporaryDirectory
()
as
tmpdir
:
state_path_8bit
=
os
.
path
.
join
(
tmpdir
,
"state_8bit.pth"
)
state_path_8bit
=
os
.
path
.
join
(
tmpdir
,
"state_8bit.pth"
)
state_path
=
os
.
path
.
join
(
tmpdir
,
"state.pth"
)
state_path
=
os
.
path
.
join
(
tmpdir
,
"state.pth"
)
...
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
with
nullcontext
()
if
has_fp16_weights
else
pytest
.
raises
(
RuntimeError
):
with
nullcontext
()
if
has_fp16_weights
else
pytest
.
raises
(
RuntimeError
):
new_linear_custom
.
load_state_dict
(
new_state_dict
,
strict
=
True
)
new_linear_custom
.
load_state_dict
(
new_state_dict
,
strict
=
True
)
if
load_before_cuda
:
new_linear_custom2
=
torch_load_from_buffer
(
bytes_8bit
)
new_linear_custom
=
new_linear_custom
.
cuda
()
new_linear_custom
=
new_linear_custom
.
cuda
()
if
not
deserialize_before_cuda
:
if
not
deserialize_before_cuda
:
new_linear_custom
.
load_state_dict
(
new_state_dict
,
strict
=
True
)
new_linear_custom
.
load_state_dict
(
new_state_dict
,
strict
=
True
)
if
not
load_before_cuda
:
new_linear_custom2
=
torch_load_from_buffer
(
bytes_8bit
)
x_second
=
x
.
clone
().
cuda
().
requires_grad_
(
True
)
x_second
=
x
.
clone
().
cuda
().
requires_grad_
(
True
)
fx_second
=
new_linear_custom
(
x_second
).
float
()
fx_second
=
new_linear_custom
(
x_second
).
float
()
(
fx_second
*
grad_proj
).
mean
().
backward
()
(
fx_second
*
grad_proj
).
mean
().
backward
()
x_third
=
x
.
clone
().
cuda
().
requires_grad_
(
True
)
fx_third
=
new_linear_custom2
(
x_third
).
float
()
(
fx_third
*
grad_proj
).
mean
().
backward
()
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
if
has_fp16_weights
or
not
deserialize_before_cuda
:
if
has_fp16_weights
or
not
deserialize_before_cuda
:
assert
torch
.
allclose
(
fx_first
,
fx_second
,
atol
=
1e-5
)
assert
torch
.
allclose
(
fx_first
,
fx_second
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_first
.
grad
,
x_second
.
grad
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_first
.
grad
,
x_second
.
grad
,
atol
=
1e-5
)
assert
torch
.
allclose
(
fx_first
,
fx_third
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_first
.
grad
,
x_third
.
grad
,
atol
=
1e-5
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment