Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
3c8c18a0
Unverified
Commit
3c8c18a0
authored
May 30, 2024
by
Titus
Committed by
GitHub
May 30, 2024
Browse files
Merge pull request #1231 from BenjaminBossan/fix-8bit-deepcopy
FIX Make Int8Params deepcopy-able
parents
c08653b1
ed99b3c1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
4 deletions
+73
-4
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+15
-4
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+58
-0
No files found.
bitsandbytes/nn/modules.py
View file @
3c8c18a0
...
...
@@ -560,13 +560,12 @@ class Int8Params(torch.nn.Parameter):
CB
=
None
,
SCB
=
None
,
):
cls
.
has_fp16_weights
=
has_fp16_weights
cls
.
CB
=
None
cls
.
SCB
=
None
if
data
is
None
:
data
=
torch
.
empty
(
0
)
obj
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
obj
.
CB
,
obj
.
SCB
=
cls
.
CB
,
cls
.
SCB
obj
.
CB
=
CB
obj
.
SCB
=
SCB
obj
.
has_fp16_weights
=
has_fp16_weights
return
obj
def
cuda
(
self
,
device
):
...
...
@@ -585,6 +584,18 @@ class Int8Params(torch.nn.Parameter):
return
self
def __deepcopy__(self, memo):
    """Return a new instance of the same type for ``copy.deepcopy``.

    ``data``, ``CB`` and ``SCB`` are deep-copied using the shared ``memo``
    dictionary; ``requires_grad`` and ``has_fp16_weights`` are forwarded
    unchanged to the constructor.
    """
    # adjust this if new arguments are added to the constructor
    cls = type(self)
    constructor_kwargs = {
        "data": copy.deepcopy(self.data, memo),
        "requires_grad": self.requires_grad,
        "has_fp16_weights": self.has_fp16_weights,
        "CB": copy.deepcopy(self.CB, memo),
        "SCB": copy.deepcopy(self.SCB, memo),
    }
    return cls.__new__(cls, **constructor_kwargs)
@
overload
def
to
(
self
:
T
,
...
...
tests/test_linear8bitlt.py
View file @
3c8c18a0
from
contextlib
import
nullcontext
import
copy
import
os
import
pickle
from
tempfile
import
TemporaryDirectory
import
pytest
...
...
@@ -177,3 +179,59 @@ def test_linear_serialization(
assert
torch
.
allclose
(
x_first
.
grad
,
x_second
.
grad
,
atol
=
1e-5
)
assert
torch
.
allclose
(
fx_first
,
fx_third
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_first
.
grad
,
x_third
.
grad
,
atol
=
1e-5
)
@pytest.fixture
def linear8bit():
    """Provide a ``Linear8bitLt`` layer built from a fresh ``torch.nn.Linear(32, 96)``.

    The layer's weight is an ``Int8Params`` wrapping a clone of the reference
    weight data, its bias is the reference layer's bias, and ``.cuda()`` is
    called on the layer before it is returned.
    """
    reference = torch.nn.Linear(32, 96)
    use_bias = reference.bias is not None
    layer = Linear8bitLt(
        reference.in_features,
        reference.out_features,
        use_bias,
        has_fp16_weights=False,
        threshold=6.0,
    )

    cloned_weight = reference.weight.data.clone()
    layer.weight = bnb.nn.Int8Params(
        cloned_weight,
        requires_grad=False,
        has_fp16_weights=False,
    )
    layer.bias = reference.bias
    return layer.cuda()
def test_linear8bit_copy_param(linear8bit):
    """A shallow copy must share weight, bias and weight storage with the original."""
    clone = copy.copy(linear8bit)
    original_weight = linear8bit.weight

    assert original_weight is clone.weight
    assert linear8bit.bias is clone.bias

    original_ptr = original_weight.data.data_ptr()
    clone_ptr = clone.weight.data.data_ptr()
    assert original_ptr == clone_ptr
def test_linear8bit_deepcopy_param(linear8bit):
    """A deep copy must own distinct weight/bias objects with equal contents.

    Also verifies that the ``SCB`` and ``CB`` attributes of the weight are
    present on the copy and compare equal to the original's.
    """
    duplicate = copy.deepcopy(linear8bit)
    src_weight = linear8bit.weight
    dup_weight = duplicate.weight

    assert src_weight is not dup_weight
    assert linear8bit.bias is not duplicate.bias
    assert src_weight.data.data_ptr() != dup_weight.data.data_ptr()
    assert torch.allclose(src_weight.data, dup_weight.data)
    assert linear8bit.state == duplicate.state

    # check for a bug where SCB and CB were not copied
    for attr_name in ("SCB", "CB"):
        copied = getattr(dup_weight, attr_name)
        assert copied is not None
        assert (getattr(src_weight, attr_name) == copied).all()
def test_linear8bit_serialization(linear8bit):
    """A pickle round-trip must yield separate storage with equal values.

    Checks weight and bias data, the layer ``state``, and the ``SCB`` / ``CB``
    attributes of the weight.
    """
    restored = pickle.loads(pickle.dumps(linear8bit))

    for param_name in ("weight", "bias"):
        before = getattr(linear8bit, param_name).data
        after = getattr(restored, param_name).data
        assert before.data_ptr() != after.data_ptr()
        assert torch.allclose(before, after)

    assert linear8bit.state == restored.state

    # check for a bug where SCB and CB were not copied
    for attr_name in ("SCB", "CB"):
        expected = getattr(linear8bit.weight, attr_name)
        actual = getattr(restored.weight, attr_name)
        assert (expected == actual).all()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment