"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3ac864dc39233e95e7261b89f8b4960adc3f88e3"
Commit ed99b3c1 authored by Benjamin Bossan's avatar Benjamin Bossan
Browse files

FIX Make Int8Params deepcopy-able

This requires implementing the __deepcopy__ method in Int8Params.
Moreover, there was an issue in the Int8Params constructor (__new__), which
assigned instance attributes to the class itself; this is now fixed.

Please review carefully that this does not impact existing code.

Tests that I ran:

- pytest tests/test_linear8bitlt.py
- in PEFT: python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_gpu_examples.py
- in PEFT: python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_common_gpu.py
- in transformers: RUN_SLOW=1 python -m pytest tests/quantization/bnb -x
parent c08653b1
...@@ -560,13 +560,12 @@ class Int8Params(torch.nn.Parameter): ...@@ -560,13 +560,12 @@ class Int8Params(torch.nn.Parameter):
CB=None, CB=None,
SCB=None, SCB=None,
): ):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
obj = torch.Tensor._make_subclass(cls, data, requires_grad) obj = torch.Tensor._make_subclass(cls, data, requires_grad)
obj.CB, obj.SCB = cls.CB, cls.SCB obj.CB = CB
obj.SCB = SCB
obj.has_fp16_weights = has_fp16_weights
return obj return obj
def cuda(self, device): def cuda(self, device):
...@@ -585,6 +584,18 @@ class Int8Params(torch.nn.Parameter): ...@@ -585,6 +584,18 @@ class Int8Params(torch.nn.Parameter):
return self return self
def __deepcopy__(self, memo):
    """Return an independent deep copy of this parameter.

    The tensor data and the ``CB``/``SCB`` attributes are deep-copied through
    ``memo``; ``requires_grad`` and ``has_fp16_weights`` are carried over
    unchanged.

    Args:
        memo: The memo dictionary supplied by ``copy.deepcopy``.

    Returns:
        A new instance of the same class as ``self``.
    """
    klass = type(self)
    # The keyword list below mirrors the constructor's arguments; extend it
    # whenever the constructor gains a new one.
    copied_data = copy.deepcopy(self.data, memo)
    copied_cb = copy.deepcopy(self.CB, memo)
    copied_scb = copy.deepcopy(self.SCB, memo)
    return klass.__new__(
        klass,
        data=copied_data,
        requires_grad=self.requires_grad,
        has_fp16_weights=self.has_fp16_weights,
        CB=copied_cb,
        SCB=copied_scb,
    )
@overload @overload
def to( def to(
self: T, self: T,
......
from contextlib import nullcontext from contextlib import nullcontext
import copy
import os import os
import pickle
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import pytest import pytest
...@@ -177,3 +179,59 @@ def test_linear_serialization( ...@@ -177,3 +179,59 @@ def test_linear_serialization(
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
assert torch.allclose(fx_first, fx_third, atol=1e-5) assert torch.allclose(fx_first, fx_third, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
@pytest.fixture
def linear8bit():
    """Provide a ``Linear8bitLt`` layer on CUDA, seeded from a ``torch.nn.Linear``.

    The layer takes its dimensions and bias from a freshly created
    ``torch.nn.Linear(32, 96)``; its weight is an ``Int8Params`` wrapping a
    clone of the reference weight data.
    """
    reference = torch.nn.Linear(32, 96)
    layer = Linear8bitLt(
        reference.in_features,
        reference.out_features,
        reference.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    # Clone the reference weights so the Int8Params does not alias them.
    layer.weight = bnb.nn.Int8Params(
        reference.weight.data.clone(),
        requires_grad=False,
        has_fp16_weights=False,
    )
    layer.bias = reference.bias
    return layer.cuda()
def test_linear8bit_copy_param(linear8bit):
    """A shallow copy of the layer must share its parameters with the original."""
    clone = copy.copy(linear8bit)

    # Same parameter objects on both modules ...
    assert linear8bit.weight is clone.weight
    assert linear8bit.bias is clone.bias

    # ... and therefore the same underlying weight storage.
    source_ptr = linear8bit.weight.data.data_ptr()
    clone_ptr = clone.weight.data.data_ptr()
    assert source_ptr == clone_ptr
def test_linear8bit_deepcopy_param(linear8bit):
    """A deep copy must own separate parameters holding equal values."""
    duplicate = copy.deepcopy(linear8bit)

    # Parameters are distinct objects backed by distinct weight storage.
    assert linear8bit.weight is not duplicate.weight
    assert linear8bit.bias is not duplicate.bias
    source_ptr = linear8bit.weight.data.data_ptr()
    duplicate_ptr = duplicate.weight.data.data_ptr()
    assert source_ptr != duplicate_ptr

    # Contents and layer state still match the original.
    assert torch.allclose(linear8bit.weight.data, duplicate.weight.data)
    assert linear8bit.state == duplicate.state

    # check for a bug where SCB and CB were not copied
    assert duplicate.weight.SCB is not None
    assert (linear8bit.weight.SCB == duplicate.weight.SCB).all()
    assert duplicate.weight.CB is not None
    assert (linear8bit.weight.CB == duplicate.weight.CB).all()
def test_linear8bit_serialization(linear8bit):
    """A pickle round trip must yield an equal layer with its own storage."""
    payload = pickle.dumps(linear8bit)
    restored = pickle.loads(payload)

    # Weight: separate storage, equal values.
    source_weight_ptr = linear8bit.weight.data.data_ptr()
    restored_weight_ptr = restored.weight.data.data_ptr()
    assert source_weight_ptr != restored_weight_ptr
    assert torch.allclose(linear8bit.weight.data, restored.weight.data)

    # Bias: separate storage, equal values.
    source_bias_ptr = linear8bit.bias.data.data_ptr()
    restored_bias_ptr = restored.bias.data.data_ptr()
    assert source_bias_ptr != restored_bias_ptr
    assert torch.allclose(linear8bit.bias.data, restored.bias.data)

    assert linear8bit.state == restored.state

    # check for a bug where SCB and CB were not copied
    assert (linear8bit.weight.SCB == restored.weight.SCB).all()
    assert (linear8bit.weight.CB == restored.weight.CB).all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment