Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
d9b1125c
Unverified
Commit
d9b1125c
authored
May 30, 2024
by
Titus
Committed by
GitHub
May 30, 2024
Browse files
Merge pull request #1230 from BenjaminBossan/fix-4bit-getstate
FIX_ Prevent __getstate__ from mutating Params4bit
parents
3c8c18a0
2fb212bd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
1 deletion
+16
-1
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+1
-1
tests/test_linear4bit.py
tests/test_linear4bit.py
+15
-0
No files found.
bitsandbytes/nn/modules.py
View file @
d9b1125c
...
@@ -236,7 +236,7 @@ class Params4bit(torch.nn.Parameter):
...
@@ -236,7 +236,7 @@ class Params4bit(torch.nn.Parameter):
return
self
return
self
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
self
.
__dict__
state
=
self
.
__dict__
.
copy
()
state
[
"data"
]
=
self
.
data
state
[
"data"
]
=
self
.
data
state
[
"requires_grad"
]
=
self
.
requires_grad
state
[
"requires_grad"
]
=
self
.
requires_grad
return
state
return
state
...
...
tests/test_linear4bit.py
View file @
d9b1125c
...
@@ -186,19 +186,30 @@ def test_copy_param():
...
@@ -186,19 +186,30 @@ def test_copy_param():
def
test_deepcopy_param
():
def
test_deepcopy_param
():
tensor
=
torch
.
tensor
([
1.0
,
2.0
,
3.0
,
4.0
])
tensor
=
torch
.
tensor
([
1.0
,
2.0
,
3.0
,
4.0
])
param
=
bnb
.
nn
.
Params4bit
(
data
=
tensor
,
requires_grad
=
False
).
cuda
(
0
)
param
=
bnb
.
nn
.
Params4bit
(
data
=
tensor
,
requires_grad
=
False
).
cuda
(
0
)
dict_keys_before
=
set
(
param
.
__dict__
.
keys
())
copy_param
=
copy
.
deepcopy
(
param
)
copy_param
=
copy
.
deepcopy
(
param
)
dict_keys_after
=
set
(
param
.
__dict__
.
keys
())
dict_keys_copy
=
set
(
copy_param
.
__dict__
.
keys
())
assert
param
.
quant_state
is
not
copy_param
.
quant_state
assert
param
.
quant_state
is
not
copy_param
.
quant_state
assert
param
.
data
.
data_ptr
()
!=
copy_param
.
data
.
data_ptr
()
assert
param
.
data
.
data_ptr
()
!=
copy_param
.
data
.
data_ptr
()
# there was a bug where deepcopy would modify the original object
assert
dict_keys_before
==
dict_keys_after
assert
dict_keys_before
==
dict_keys_copy
def
test_params4bit_real_serialization
():
def
test_params4bit_real_serialization
():
original_tensor
=
torch
.
tensor
([
1.0
,
2.0
,
3.0
,
4.0
],
dtype
=
torch
.
float32
)
original_tensor
=
torch
.
tensor
([
1.0
,
2.0
,
3.0
,
4.0
],
dtype
=
torch
.
float32
)
original_param
=
bnb
.
nn
.
Params4bit
(
data
=
original_tensor
,
quant_type
=
"fp4"
)
original_param
=
bnb
.
nn
.
Params4bit
(
data
=
original_tensor
,
quant_type
=
"fp4"
)
dict_keys_before
=
set
(
original_param
.
__dict__
.
keys
())
original_param
.
cuda
(
0
)
# move to CUDA to trigger quantization
original_param
.
cuda
(
0
)
# move to CUDA to trigger quantization
serialized_param
=
pickle
.
dumps
(
original_param
)
serialized_param
=
pickle
.
dumps
(
original_param
)
deserialized_param
=
pickle
.
loads
(
serialized_param
)
deserialized_param
=
pickle
.
loads
(
serialized_param
)
dict_keys_after
=
set
(
original_param
.
__dict__
.
keys
())
dict_keys_deserialized
=
set
(
deserialized_param
.
__dict__
.
keys
())
assert
torch
.
equal
(
original_param
.
data
,
deserialized_param
.
data
)
assert
torch
.
equal
(
original_param
.
data
,
deserialized_param
.
data
)
assert
original_param
.
requires_grad
==
deserialized_param
.
requires_grad
==
False
assert
original_param
.
requires_grad
==
deserialized_param
.
requires_grad
==
False
...
@@ -206,3 +217,7 @@ def test_params4bit_real_serialization():
...
@@ -206,3 +217,7 @@ def test_params4bit_real_serialization():
assert
original_param
.
blocksize
==
deserialized_param
.
blocksize
assert
original_param
.
blocksize
==
deserialized_param
.
blocksize
assert
original_param
.
compress_statistics
==
deserialized_param
.
compress_statistics
assert
original_param
.
compress_statistics
==
deserialized_param
.
compress_statistics
assert
original_param
.
quant_state
==
deserialized_param
.
quant_state
assert
original_param
.
quant_state
==
deserialized_param
.
quant_state
# there was a bug where deepcopy would modify the original object
assert
dict_keys_before
==
dict_keys_after
assert
dict_keys_before
==
dict_keys_deserialized
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment