Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
3c8c18a0
Unverified
Commit
3c8c18a0
authored
May 30, 2024
by
Titus
Committed by
GitHub
May 30, 2024
Browse files
Merge pull request #1231 from BenjaminBossan/fix-8bit-deepcopy
FIX Make Int8Params deepcopy-able
parents
c08653b1
ed99b3c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
4 deletions
+73
-4
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+15
-4
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+58
-0
No files found.
bitsandbytes/nn/modules.py
View file @
3c8c18a0
...
...
@@ -560,13 +560,12 @@ class Int8Params(torch.nn.Parameter):
CB
=
None
,
SCB
=
None
,
):
cls
.
has_fp16_weights
=
has_fp16_weights
cls
.
CB
=
None
cls
.
SCB
=
None
if
data
is
None
:
data
=
torch
.
empty
(
0
)
obj
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
obj
.
CB
,
obj
.
SCB
=
cls
.
CB
,
cls
.
SCB
obj
.
CB
=
CB
obj
.
SCB
=
SCB
obj
.
has_fp16_weights
=
has_fp16_weights
return
obj
def
cuda
(
self
,
device
):
...
...
@@ -585,6 +584,18 @@ class Int8Params(torch.nn.Parameter):
return
self
def __deepcopy__(self, memo):
    """Build an independent copy of this parameter for ``copy.deepcopy``.

    ``data``, ``CB`` and ``SCB`` are deep-copied through the shared ``memo``
    dictionary; ``requires_grad`` and ``has_fp16_weights`` are forwarded
    unchanged. The copy is created through ``__new__`` of the instance's own
    type, so subclasses are preserved.
    """
    param_cls = type(self)
    allocate = param_cls.__new__
    # NOTE: keep these keywords in sync with the constructor's arguments.
    ctor_kwargs = {
        "data": copy.deepcopy(self.data, memo),
        "requires_grad": self.requires_grad,
        "has_fp16_weights": self.has_fp16_weights,
        "CB": copy.deepcopy(self.CB, memo),
        "SCB": copy.deepcopy(self.SCB, memo),
    }
    return allocate(param_cls, **ctor_kwargs)
@
overload
def
to
(
self
:
T
,
...
...
tests/test_linear8bitlt.py
View file @
3c8c18a0
from
contextlib
import
nullcontext
import
copy
import
os
import
pickle
from
tempfile
import
TemporaryDirectory
import
pytest
...
...
@@ -177,3 +179,59 @@ def test_linear_serialization(
assert
torch
.
allclose
(
x_first
.
grad
,
x_second
.
grad
,
atol
=
1e-5
)
assert
torch
.
allclose
(
fx_first
,
fx_third
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_first
.
grad
,
x_third
.
grad
,
atol
=
1e-5
)
@pytest.fixture
def linear8bit():
    """Provide a ``Linear8bitLt`` layer moved to CUDA.

    The layer mirrors a freshly created ``torch.nn.Linear(32, 96)``: its
    weight is an ``Int8Params`` wrapping a clone of the reference weight
    data, and its bias is the reference layer's bias object.
    """
    reference = torch.nn.Linear(32, 96)
    with_bias = reference.bias is not None

    layer = Linear8bitLt(
        reference.in_features,
        reference.out_features,
        with_bias,
        has_fp16_weights=False,
        threshold=6.0,
    )

    # Clone the weight data so the 8-bit parameter does not share the
    # reference layer's tensor.
    layer.weight = bnb.nn.Int8Params(
        reference.weight.data.clone(),
        requires_grad=False,
        has_fp16_weights=False,
    )
    layer.bias = reference.bias

    return layer.cuda()
def test_linear8bit_copy_param(linear8bit):
    """A shallow copy of the layer must share its parameters with the original."""
    duplicate = copy.copy(linear8bit)

    # The parameter objects themselves are shared, not re-created.
    assert linear8bit.weight is duplicate.weight
    assert linear8bit.bias is duplicate.bias

    # The weight data of both layers starts at the same address.
    original_ptr = linear8bit.weight.data.data_ptr()
    duplicate_ptr = duplicate.weight.data.data_ptr()
    assert original_ptr == duplicate_ptr
def test_linear8bit_deepcopy_param(linear8bit):
    """A deep copy must own fresh parameters that hold the same values."""
    duplicate = copy.deepcopy(linear8bit)
    source_weight = linear8bit.weight
    copied_weight = duplicate.weight

    # Parameters are distinct objects whose data lives at a different address.
    assert source_weight is not copied_weight
    assert linear8bit.bias is not duplicate.bias
    assert source_weight.data.data_ptr() != copied_weight.data.data_ptr()

    # Values and layer state still match the original.
    assert torch.allclose(source_weight.data, copied_weight.data)
    assert linear8bit.state == duplicate.state

    # check for a bug where SCB and CB were not copied
    assert copied_weight.SCB is not None
    assert (source_weight.SCB == copied_weight.SCB).all()
    assert copied_weight.CB is not None
    assert (source_weight.CB == copied_weight.CB).all()
def test_linear8bit_serialization(linear8bit):
    """A pickle round trip must yield an equal layer whose tensor data is at new addresses."""
    restored = pickle.loads(pickle.dumps(linear8bit))

    source_weight, restored_weight = linear8bit.weight, restored.weight
    source_bias, restored_bias = linear8bit.bias, restored.bias

    # Weight: different data address, same values.
    assert source_weight.data.data_ptr() != restored_weight.data.data_ptr()
    assert torch.allclose(source_weight.data, restored_weight.data)

    # Bias: different data address, same values.
    assert source_bias.data.data_ptr() != restored_bias.data.data_ptr()
    assert torch.allclose(source_bias.data, restored_bias.data)

    assert linear8bit.state == restored.state

    # check for a bug where SCB and CB were not copied
    assert (source_weight.SCB == restored_weight.SCB).all()
    assert (source_weight.CB == restored_weight.CB).all()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment