"vscode:/vscode.git/clone" did not exist on "97f199569620537c3d0d8821c3e23d87c81d7bd8"
Unverified commit cfd6ac75 authored by Marc Sun, committed by GitHub

add deepcopy and copy for Params4bit (#1060)



* fix deepcopy and copy

* add tests

* remove line

* ruff fix

* ruff

* Update tests/test_linear4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* add missing state

* ruff format

* ignore formatting commit for git blame

* Params4bit should be initialized as frozen by default

* add test for serialization round-tripping

* add comparison capability for QuantState

* add back accidentally removed line

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
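In short, a quantized Params4bit can now be shallow-copied, deep-copied, and pickled like an ordinary parameter. A minimal usage sketch of what the change enables (illustrative only, not taken from the commit; it assumes a CUDA device so that `.cuda(0)` triggers quantization, mirroring the new tests below):

```python
import copy
import pickle

import torch
import bitsandbytes as bnb

# Illustrative only: assumes a CUDA device so moving the parameter quantizes it.
w = torch.randn(300, 400, dtype=torch.float32)
param = bnb.nn.Params4bit(data=w, quant_type="nf4", requires_grad=False).cuda(0)

# A shallow copy shares the quantized buffer and its QuantState ...
shallow = copy.copy(param)
assert shallow.quant_state is param.quant_state

# ... while a deep copy owns independent ones.
deep = copy.deepcopy(param)
assert deep.quant_state is not param.quant_state

# Pickling now round-trips the quantized data together with its QuantState.
restored = pickle.loads(pickle.dumps(param))
assert restored.quant_state == param.quant_state
```

The shallow/deep split and the pickle round-trip are exactly the properties checked by the tests added in tests/test_linear4bit.py further down.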
parent b0730f4d
.git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
\ No newline at end of file
bitsandbytes/functional.py
@@ -706,6 +706,21 @@ class QuantState:
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)
def __eq__(self, other):
if not isinstance(other, QuantState):
return False
return (
torch.allclose(self.absmax, other.absmax, atol=1e-6) and
self.shape == other.shape and
torch.allclose(self.code, other.code, atol=1e-6) and
self.dtype == other.dtype and
self.blocksize == other.blocksize and
self.quant_type == other.quant_type and
(self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
(self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
)
def quantize_blockwise(
A: Tensor,
bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings
@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
def __new__(
cls,
data: Optional[torch.Tensor] = None,
-requires_grad=True,
+requires_grad=False, # quantized weights should be frozen by default
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
@@ -214,6 +215,37 @@ class Params4bit(torch.nn.Parameter):
self.module = module
return self
def __getstate__(self):
state = self.__dict__
state["data"] = self.data
state["requires_grad"] = self.requires_grad
return state
def __setstate__(self, state):
self.requires_grad = state["requires_grad"]
self.blocksize = state["blocksize"]
self.compress_statistics = state["compress_statistics"]
self.quant_type = state["quant_type"]
self.quant_state = state["quant_state"]
self.data = state["data"]
self.quant_storage = state["quant_storage"]
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]
def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
new_instance.quant_state = copy.deepcopy(state["quant_state"])
new_instance.data = copy.deepcopy(state["data"])
return new_instance
def __copy__(self):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
return new_instance
@classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ class Params4bit(torch.nn.Parameter):
def _quantize(self, device):
w = self.data.contiguous().cuda(device)
-w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-    quant_type=self.quant_type, quant_storage=self.quant_storage)
+w_4bit, quant_state = bnb.functional.quantize_4bit(
+    w,
+    blocksize=self.blocksize,
+    compress_statistics=self.compress_statistics,
+    quant_type=self.quant_type,
+    quant_storage=self.quant_storage,
+)
self.data = w_4bit
self.quant_state = quant_state
if self.module is not None:
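A brief note on the pickling hooks above: `.data` and `.requires_grad` on a torch.nn.Parameter subclass are tensor-level attributes that do not live in `__dict__`, which is why `__getstate__` has to add them explicitly before `__setstate__` and `__deepcopy__` can rebuild the parameter. A quick check of that behavior (illustrative sketch, not part of the diff):

```python
import torch

# Tensor-level attributes are not kept in the instance __dict__, so pickling
# __dict__ alone would silently drop the weights and the frozen flag.
p = torch.nn.Parameter(torch.zeros(2), requires_grad=False)
assert "data" not in p.__dict__
assert "requires_grad" not in p.__dict__
```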
tests/test_linear4bit.py
import copy
import os
import pickle
from tempfile import TemporaryDirectory
import pytest
@@ -8,13 +10,14 @@ import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE
storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
}
@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
device = "cuda"
layer_shape = (300, 400)
-linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
+linear = torch.nn.Linear(
+    *layer_shape, dtype=original_dtype, device="cpu"
+) # original layer
# Quantizing original layer
linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_type=quant_type,
device="meta",
)
-new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+new_weight = bnb.nn.Params4bit(
+    data=linear.weight, quant_type=quant_type, requires_grad=False
+)
linear_q.weight = new_weight
if bias:
linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_storage=storage[quant_storage],
device="meta",
)
-linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+linear_qs.weight = bnb.nn.Params4bit(
+    data=linear.weight,
+    requires_grad=False,
+    quant_type=quant_type,
+    quant_storage=storage[quant_storage],
+)
if bias:
linear_qs.bias = torch.nn.Parameter(linear.bias)
linear_qs = linear_qs.to(device)
@@ -91,7 +103,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
q0 = a.quant_state
q1 = b.quant_state
-for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+for attr in ("code", "dtype", "blocksize", "absmax"):
c, d = getattr(q0, attr), getattr(q1, attr)
if isinstance(c, torch.Tensor):
assert torch.equal(c, d)
@@ -99,7 +111,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert c == d, f"{c} != {d}"
if q0.state2 is not None:
-for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+for attr in ("code", "dtype", "blocksize", "absmax"):
c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
if isinstance(c, torch.Tensor):
assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert torch.equal(a, c)
# Test moving to CPU and back to GPU
-linear_q2.to('cpu')
+linear_q2.to("cpu")
linear_q2.to(device)
d = linear_qs(x)
assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
torch.save(linear.state_dict(), state_path)
torch.save(linear_q.state_dict(), state_path_4bit)
-size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-    state_path_4bit
+size_orig, size_4 = (
+    os.path.getsize(state_path),
+    os.path.getsize(state_path_4bit),
)
size_ratio = size_4 / size_orig
-target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases
+target_compression = (
+    0.143 if original_dtype == torch.float32 else 0.29
+) # these numbers get lower as weight shape increases
ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
assert size_ratio < target_compression, ratio_error_msg
def test_copy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
shallow_copy_param = copy.copy(param)
assert param.quant_state is shallow_copy_param.quant_state
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
def test_deepcopy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
copy_param = copy.deepcopy(param)
assert param.quant_state is not copy_param.quant_state
assert param.data.data_ptr() != copy_param.data.data_ptr()
def test_params4bit_real_serialization():
original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
original_param.cuda(0) # move to CUDA to trigger quantization
serialized_param = pickle.dumps(original_param)
deserialized_param = pickle.loads(serialized_param)
assert torch.equal(original_param.data, deserialized_param.data)
assert original_param.requires_grad == deserialized_param.requires_grad == False
assert original_param.quant_type == deserialized_param.quant_type
assert original_param.blocksize == deserialized_param.blocksize
assert original_param.compress_statistics == deserialized_param.compress_statistics
assert original_param.quant_state == deserialized_param.quant_state
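Since QuantState now defines `__eq__`, the attribute-by-attribute loops in test_linear_serialization above could in principle collapse to a single comparison. A possible simplification (sketch only, not applied in this commit):

```python
# q0 and q1 are the quant_state objects compared field-by-field above; the new
# __eq__ checks absmax, code, shape, dtype, blocksize, quant_type, offset and
# the nested state2 in one call.
assert q0 == q1
```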