OpenDAS / bitsandbytes · Commits

Commit ac3ab281, authored Feb 25, 2023 by Max Ryabinin
Parent: 58b09ee1

Handle more cases in test_linear_serialization

Showing 1 changed file with 40 additions and 13 deletions.

tests/test_linear8bitlt.py (+40, -13)
 from copy import deepcopy
 import os
+from contextlib import nullcontext
+from itertools import product
 from tempfile import TemporaryDirectory

 import pytest
 import torch
...
@@ -66,10 +69,11 @@ def test_linear_no_igemmlt():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("has_fp16_weights", [False, True])
-def test_linear_serialization(has_fp16_weights):
-    linear = torch.nn.Linear(16, 32)
-    x = torch.randn(3, 16, dtype=torch.half)
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
+                         list(product([False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)

     linear_custom = Linear8bitLt(
         linear.in_features,
...
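The hunk above broadens the test matrix: instead of a single `has_fp16_weights` flag, the test now also varies whether the state dict is taken before or after the first forward pass and whether deserialization happens before or after moving the module to CUDA. `itertools.product` over three two-element lists expands to all 2**3 = 8 combinations, one pytest case each. A minimal standalone sketch of the same pattern (generic names, not code from this repository):

from itertools import product

import pytest

# product([False, True], ...) yields every combination of the three
# boolean flags, so pytest generates 8 independent test cases.
CASES = list(product([False, True], [False, True], [False, True]))
assert len(CASES) == 8  # (False, False, False) ... (True, True, True)

@pytest.mark.parametrize("a, b, c", CASES)
def test_flags(a, b, c):
    assert isinstance(a, bool) and isinstance(b, bool) and isinstance(c, bool)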
@@ -78,19 +82,34 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
     linear_custom.state.force_no_igemmlt = True
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
-    ).to(linear.weight.dtype)
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()

+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()

     state_dict = deepcopy(linear_custom.state_dict())

+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")

+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)

         if not has_fp16_weights:
             assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

         new_state_dict = torch.load(state_path_8bit)

     new_linear_custom = Linear8bitLt(
         linear.in_features,
...
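The `getsize` assertion in this hunk works because the reference `linear` is saved in its default fp32 precision (four bytes per element), while the quantized layer stores its weight as int8 (one byte per element) plus per-row scale statistics, so even with pickle overhead the 8-bit checkpoint stays well under half the size. A back-of-the-envelope sketch of that arithmetic, assuming the 32x96 layer from the test (illustrative only, the exact per-row scale layout is an assumption):

# Rough arithmetic behind
# `os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)`.
in_features, out_features = 32, 96

# fp32 reference checkpoint: weight + bias at 4 bytes per element.
fp32_bytes = out_features * in_features * 4 + out_features * 4

# 8-bit checkpoint: int8 weight (1 byte per element) + fp32 bias
# + roughly one fp32 scale per output row (assumed layout).
int8_bytes = out_features * in_features * 1 + out_features * 4 + out_features * 4

print(fp32_bytes)                      # 12672
print(int8_bytes)                      # 3840
print(int8_bytes < 0.5 * fp32_bytes)   # True, with room for overhead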
@@ -99,13 +118,21 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
     linear_custom.state.force_no_igemmlt = True

+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
     new_linear_custom = new_linear_custom.cuda()
-    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)

     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()

+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
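The conditional context manager in the last hunk, `nullcontext() if has_fp16_weights else pytest.raises(RuntimeError)`, lets one code path assert two behaviors: fp16 weights may be loaded into a CPU module, while loading 8-bit weights before `.cuda()` must raise. A generic, self-contained sketch of that pattern (standard contextlib/pytest behavior, not code from this repository):

from contextlib import nullcontext

import pytest

def may_fail(should_raise: bool) -> None:
    if should_raise:
        raise RuntimeError("boom")

@pytest.mark.parametrize("should_raise", [False, True])
def test_conditional_raises(should_raise):
    # nullcontext() is a no-op manager, so the body runs unchecked when
    # no error is expected; otherwise pytest.raises enforces the failure.
    with pytest.raises(RuntimeError) if should_raise else nullcontext():
        may_fail(should_raise)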
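Taken together, the test pins down the supported save/load flow for Linear8bitLt: quantize on the move to CUDA, serialize the resulting state dict, and deserialize only after the fresh module is already on the GPU. A hedged usage sketch assembled from the calls visible in the diff (assumes a CUDA device and a working bitsandbytes install; not an official example):

# Sketch of the save/load flow exercised by the test; assumes CUDA.
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

linear = torch.nn.Linear(32, 96)
quant = Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0)
quant.weight = bnb.nn.Int8Params(
    linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
)
quant.bias = linear.bias
quant = quant.cuda()  # quantization happens on the move to the GPU

state = quant.state_dict()  # serializable 8-bit state

# Move the fresh layer to CUDA first: per the test, loading 8-bit
# weights into a CPU module raises RuntimeError.
fresh = Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0)
fresh = fresh.cuda()
fresh.load_state_dict(state, strict=True)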