Commit 2fb212bd authored by Benjamin Bossan

FIX Prevent __getstate__ from mutating Params4bit

As discussed internally, use state = self.__dict__.copy(), which is also
what the Python docs recommend.
parent c08653b1
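
Why the one-character change matters: __getstate__ is invoked both by pickle and by copy.deepcopy. Returning self.__dict__ directly means that the extra keys added for serialization ("data", "requires_grad") are written straight into the live object's own attribute dict. Below is a minimal, self-contained sketch of the failure mode and of the fix; the LeakyState and CleanState class names are illustrative stand-ins, not part of bitsandbytes.

import copy
import pickle


class LeakyState:
    # Buggy pattern: __getstate__ returns self.__dict__ itself, so the
    # extra key added for serialization also lands on the live object.
    def __init__(self):
        self.value = 1

    def __getstate__(self):
        state = self.__dict__              # aliases the instance dict
        state["extra"] = "only-for-pickle"
        return state


class CleanState:
    # Fixed pattern (what this commit does): work on a copy of __dict__,
    # so the original object is left untouched.
    def __init__(self):
        self.value = 1

    def __getstate__(self):
        state = self.__dict__.copy()       # detached copy
        state["extra"] = "only-for-pickle"
        return state


leaky, clean = LeakyState(), CleanState()
pickle.dumps(leaky)
pickle.dumps(clean)
print(hasattr(leaky, "extra"))  # True  -> pickling mutated the original object
print(hasattr(clean, "extra"))  # False -> original object unchanged

# copy.deepcopy also goes through __getstate__, which is why the new
# test_deepcopy_param assertions compare __dict__ keys before and after.
copy.deepcopy(leaky)
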
@@ -236,7 +236,7 @@ class Params4bit(torch.nn.Parameter):
         return self
 
     def __getstate__(self):
-        state = self.__dict__
+        state = self.__dict__.copy()
         state["data"] = self.data
         state["requires_grad"] = self.requires_grad
         return state
@@ -186,19 +186,30 @@ def test_copy_param():
 def test_deepcopy_param():
     tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
+    dict_keys_after = set(param.__dict__.keys())
+    dict_keys_copy = set(copy_param.__dict__.keys())
 
     assert param.quant_state is not copy_param.quant_state
     assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+    # there was a bug where deepcopy would modify the original object
+    assert dict_keys_before == dict_keys_after
+    assert dict_keys_before == dict_keys_copy
 
 
 def test_params4bit_real_serialization():
     original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
     original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+    dict_keys_before = set(original_param.__dict__.keys())
 
     original_param.cuda(0)  # move to CUDA to trigger quantization
 
     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
+    dict_keys_after = set(original_param.__dict__.keys())
+    dict_keys_deserialized = set(deserialized_param.__dict__.keys())
 
     assert torch.equal(original_param.data, deserialized_param.data)
     assert original_param.requires_grad == deserialized_param.requires_grad == False
@@ -206,3 +217,7 @@ def test_params4bit_real_serialization():
     assert original_param.blocksize == deserialized_param.blocksize
     assert original_param.compress_statistics == deserialized_param.compress_statistics
     assert original_param.quant_state == deserialized_param.quant_state
+
+    # there was a bug where deepcopy would modify the original object
+    assert dict_keys_before == dict_keys_after
+    assert dict_keys_before == dict_keys_deserialized