Unverified Commit cfd6ac75 authored by Marc Sun, committed by GitHub

add deepcopy and copy for Param4bit (#1060)



* fix deepcopy and copy

* add tests

* remove line

* ruff fix

* ruff

* Update tests/test_linear4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* add missing state

* ruff format

* ignore formatting commit for git blame

* Params4bit should be initialized as frozen by default

* add test for serialization round-tripping

* add comparison capability for QuantState

* add back accidentally removed line

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent b0730f4d
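For context, a minimal usage sketch (not part of the commit) of the behaviour this change enables: shallow-copying and deep-copying a quantized Params4bit without losing its quantization state. It assumes bitsandbytes with this patch and an available CUDA device; the tensor shape and values are illustrative only.

import copy

import torch
import bitsandbytes as bnb

# Quantization happens when the parameter is moved to the GPU.
weight = torch.randn(64, 64, dtype=torch.float32)
param = bnb.nn.Params4bit(data=weight, quant_type="nf4", requires_grad=False).cuda(0)

# copy.copy shares the quantized storage and the QuantState ...
shallow = copy.copy(param)
assert shallow.quant_state is param.quant_state
assert shallow.data.data_ptr() == param.data.data_ptr()

# ... while copy.deepcopy duplicates both, giving an independent parameter.
deep = copy.deepcopy(param)
assert deep.quant_state is not param.quant_state
assert torch.equal(deep.data, param.data)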
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
\ No newline at end of file
@@ -706,6 +706,21 @@ class QuantState:
            self.state2.absmax = self.state2.absmax.to(device)
            self.state2.code = self.state2.code.to(device)

    def __eq__(self, other):
        if not isinstance(other, QuantState):
            return False

        return (
            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
            self.shape == other.shape and
            torch.allclose(self.code, other.code, atol=1e-6) and
            self.dtype == other.dtype and
            self.blocksize == other.blocksize and
            self.quant_type == other.quant_type and
            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
        )


def quantize_blockwise(
    A: Tensor,
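A rough sketch of how the new QuantState.__eq__ can be used (illustrative, not part of the diff; it assumes a CUDA device and the default nf4 settings): quantizing the same tensor twice with identical settings yields states that compare equal, while comparing against anything that is not a QuantState simply returns False.

import torch
import bitsandbytes as bnb

w = torch.randn(128, 128)
p1 = bnb.nn.Params4bit(data=w.clone(), quant_type="nf4", requires_grad=False).cuda(0)
p2 = bnb.nn.Params4bit(data=w.clone(), quant_type="nf4", requires_grad=False).cuda(0)

# Field-wise comparison: absmax/code within tolerance, plus shape, dtype,
# blocksize, quant_type, offset, and the nested state2.
assert p1.quant_state == p2.quant_state
assert (p1.quant_state == "not a QuantState") is False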
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
    def __new__(
        cls,
        data: Optional[torch.Tensor] = None,
        requires_grad=False,  # quantized weights should be frozen by default
        quant_state: Optional[QuantState] = None,
        blocksize: int = 64,
        compress_statistics: bool = True,

@@ -214,6 +215,37 @@ class Params4bit(torch.nn.Parameter):
        self.module = module
        return self

    def __getstate__(self):
        state = self.__dict__
        state["data"] = self.data
        state["requires_grad"] = self.requires_grad
        return state

    def __setstate__(self, state):
        self.requires_grad = state["requires_grad"]
        self.blocksize = state["blocksize"]
        self.compress_statistics = state["compress_statistics"]
        self.quant_type = state["quant_type"]
        self.quant_state = state["quant_state"]
        self.data = state["data"]
        self.quant_storage = state["quant_storage"]
        self.bnb_quantized = state["bnb_quantized"]
        self.module = state["module"]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        new_instance.quant_state = copy.deepcopy(state["quant_state"])
        new_instance.data = copy.deepcopy(state["data"])
        return new_instance

    def __copy__(self):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    @classmethod
    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
        self = torch.Tensor._make_subclass(cls, data.to(device))

@@ -227,8 +259,13 @@ class Params4bit(torch.nn.Parameter):
    def _quantize(self, device):
        w = self.data.contiguous().cuda(device)
        w_4bit, quant_state = bnb.functional.quantize_4bit(
            w,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
        )
        self.data = w_4bit
        self.quant_state = quant_state
        if self.module is not None:
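Deep-copying also works one level up, on a module that owns a Params4bit, which is the use case this patch targets. A small sketch (assumptions: CUDA device available, compute_dtype chosen here for convenience; torch.allclose is used rather than exact equality to stay robust to kernel differences):

import copy

import torch
import bitsandbytes as bnb

# Moving the layer to the GPU quantizes its Params4bit weight; deepcopy then
# goes through the __getstate__/__deepcopy__ hooks added above.
layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4").cuda(0)
layer_copy = copy.deepcopy(layer)

# The copy carries its own, independent quantization state ...
assert layer_copy.weight.quant_state is not layer.weight.quant_state

# ... and produces the same outputs as the original.
x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
assert torch.allclose(layer(x), layer_copy(x))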
import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest

@@ -8,13 +10,14 @@ import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE

storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(
        *layer_shape, dtype=original_dtype, device="cpu"
    )  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(

@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(
        data=linear.weight, quant_type=quant_type, requires_grad=False
    )
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)

@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

@@ -91,7 +103,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    q0 = a.quant_state
    q1 = b.quant_state

    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)

@@ -99,7 +111,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)

@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    assert torch.equal(a, c)

    # Test moving to CPU and back to GPU
    linear_q2.to("cpu")
    linear_q2.to(device)
    d = linear_qs(x)
    assert c.dtype == d.dtype

@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        assert size_ratio < target_compression, ratio_error_msg
def test_copy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


def test_deepcopy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
    copy_param = copy.deepcopy(param)
    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()


def test_params4bit_real_serialization():
    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
    original_param.cuda(0)  # move to CUDA to trigger quantization

    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad == False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state