Unverified Commit 3c8c18a0 authored by Titus's avatar Titus Committed by GitHub
Browse files

Merge pull request #1231 from BenjaminBossan/fix-8bit-deepcopy

FIX Make Int8Params deepcopy-able
parents c08653b1 ed99b3c1
......@@ -560,13 +560,12 @@ class Int8Params(torch.nn.Parameter):
CB=None,
SCB=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
if data is None:
data = torch.empty(0)
obj = torch.Tensor._make_subclass(cls, data, requires_grad)
obj.CB, obj.SCB = cls.CB, cls.SCB
obj.CB = CB
obj.SCB = SCB
obj.has_fp16_weights = has_fp16_weights
return obj
def cuda(self, device):
......@@ -585,6 +584,18 @@ class Int8Params(torch.nn.Parameter):
return self
def __deepcopy__(self, memo):
    """Return a deep copy of this parameter as a new instance of the same class.

    ``data``, ``CB`` and ``SCB`` are deep-copied (with ``memo`` forwarded to
    each nested ``copy.deepcopy`` call); ``requires_grad`` and
    ``has_fp16_weights`` are passed through unchanged.
    """
    cls = type(self)
    # NOTE: keep these keywords in sync with the constructor's arguments.
    ctor_kwargs = {
        "data": copy.deepcopy(self.data, memo),
        "requires_grad": self.requires_grad,
        "has_fp16_weights": self.has_fp16_weights,
        "CB": copy.deepcopy(self.CB, memo),
        "SCB": copy.deepcopy(self.SCB, memo),
    }
    return cls.__new__(cls, **ctor_kwargs)
@overload
def to(
self: T,
......
from contextlib import nullcontext
import copy
import os
import pickle
from tempfile import TemporaryDirectory
import pytest
......@@ -177,3 +179,59 @@ def test_linear_serialization(
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
assert torch.allclose(fx_first, fx_third, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
@pytest.fixture
def linear8bit():
    """Provide a ``Linear8bitLt`` layer moved to CUDA.

    The layer mirrors a freshly created ``torch.nn.Linear(32, 96)``: its
    weight is an ``Int8Params`` wrapping a clone of the reference weight data,
    and its bias is the reference layer's bias.
    """
    reference = torch.nn.Linear(32, 96)
    use_bias = reference.bias is not None

    layer = Linear8bitLt(
        reference.in_features,
        reference.out_features,
        use_bias,
        has_fp16_weights=False,
        threshold=6.0,
    )

    # Swap in an Int8Params weight built from a clone of the reference weights.
    cloned_weight = reference.weight.data.clone()
    layer.weight = bnb.nn.Int8Params(
        cloned_weight,
        requires_grad=False,
        has_fp16_weights=False,
    )
    layer.bias = reference.bias

    return layer.cuda()
def test_linear8bit_copy_param(linear8bit):
    """A shallow copy must share the weight and bias objects with the original."""
    clone = copy.copy(linear8bit)

    # The very same parameter objects are expected on both modules.
    assert linear8bit.weight is clone.weight
    assert linear8bit.bias is clone.bias

    # The weight data must point at the same memory address as well.
    original_ptr = linear8bit.weight.data.data_ptr()
    clone_ptr = clone.weight.data.data_ptr()
    assert original_ptr == clone_ptr
def test_linear8bit_deepcopy_param(linear8bit):
    """A deep copy must own distinct parameters with equal values, including ``SCB`` and ``CB``."""
    duplicate = copy.deepcopy(linear8bit)

    # Parameter objects must be new ones.
    assert linear8bit.weight is not duplicate.weight
    assert linear8bit.bias is not duplicate.bias

    # The weight data lives at a different address but holds the same values.
    source_data = linear8bit.weight.data
    copied_data = duplicate.weight.data
    assert source_data.data_ptr() != copied_data.data_ptr()
    assert torch.allclose(linear8bit.weight.data, duplicate.weight.data)

    assert linear8bit.state == duplicate.state

    # Guards against a bug where SCB and CB were not carried over to the copy.
    for attr in ("SCB", "CB"):
        copied = getattr(duplicate.weight, attr)
        assert copied is not None
        assert (getattr(linear8bit.weight, attr) == copied).all()
def test_linear8bit_serialization(linear8bit):
    """A pickle round trip must give equal weight, bias, state, ``SCB`` and ``CB`` in new memory."""
    deserialized = pickle.loads(pickle.dumps(linear8bit))

    # Weight, then bias: different data address, same values.
    for param_name in ("weight", "bias"):
        before = getattr(linear8bit, param_name).data
        after = getattr(deserialized, param_name).data
        assert before.data_ptr() != after.data_ptr()
        assert torch.allclose(before, after)

    assert linear8bit.state == deserialized.state

    # Guards against a bug where SCB and CB were not carried over.
    for attr in ("SCB", "CB"):
        assert (getattr(linear8bit.weight, attr) == getattr(deserialized.weight, attr)).all()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment