Unverified Commit a1c0844b authored by rdyro, committed by GitHub

adding whole Linear8bitLt/Linear4bit module save/load serialization (#1099)

parent f9eba9c8
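The commit makes torch.save / torch.load work on whole quantized modules rather than only on their state_dicts. A minimal usage sketch of what this enables (hedged: the layer sizes, dtypes, and in-memory buffer are illustrative, following the test helpers added below, not code from this commit):

    import torch
    import bitsandbytes as bnb
    from io import BytesIO

    # Quantization happens when the module is moved to the GPU.
    linear = bnb.nn.Linear4bit(64, 64, quant_type="nf4").cuda()

    # Save the whole module, not just its state_dict ...
    buffer = BytesIO()
    torch.save(linear, buffer)
    buffer.seek(0)

    # ... and load it back with the quantization state intact.
    restored = torch.load(buffer)
    x = torch.rand(3, 64, device="cuda", dtype=torch.float16)
    assert torch.equal(linear(x), restored(x))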
@@ -449,7 +449,9 @@ class Int8Params(torch.nn.Parameter):
         cls.SCB = None
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj.CB, obj.SCB = cls.CB, cls.SCB
+        return obj

     def cuda(self, device):
         if self.has_fp16_weights:
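Why the Int8Params change is needed: torch.Tensor._make_subclass returns a fresh instance, and the quantization statistics assigned onto the class (cls.CB, cls.SCB) do not travel with the tensor payload, so they could be dropped whenever a new instance is created, for example when a module holding this parameter is deep-copied or unpickled. Copying them onto the returned object keeps the statistics attached to the parameter itself. A self-contained sketch of the pattern (the Tagged class is hypothetical, for illustration only):

    import torch

    class Tagged(torch.nn.Parameter):
        def __new__(cls, data=None, requires_grad=True, tag=None):
            if data is None:
                data = torch.empty(0)
            obj = torch.Tensor._make_subclass(cls, data, requires_grad)
            # Attributes must live on the instance _make_subclass returns;
            # anything left only on the class is lost on the next instance.
            obj.tag = tag
            return obj

    p = Tagged(torch.zeros(2), tag="SCB")
    assert p.tag == "SCB"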
...
 import copy
+from io import BytesIO
 import os
 import pickle
 from tempfile import TemporaryDirectory
@@ -16,12 +17,24 @@ storage = {
     "float32": torch.float32,
 }

+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         assert a.dtype == b.dtype
         assert torch.equal(a, b)

+    if save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+
     # Forward test
     x = torch.rand(42, layer_shape[0], device=device)
     a = linear_q(x)
@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert torch.equal(a, b)
     assert torch.equal(a, c)

+    if not save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+    linear_q3 = torch_load_from_buffer(bytes_4bit)
+
     # Test moving to CPU and back to GPU
     linear_q2.to("cpu")
     linear_q2.to(device)
@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert c.device == d.device
     assert torch.equal(c, d)

+    d = linear_q3(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)
+
     # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
     with TemporaryDirectory() as tmpdir:
         state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
...
 from contextlib import nullcontext
+from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None

+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)
@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()

+    if save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if not serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()

+    if not save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
         new_linear_custom.load_state_dict(new_state_dict, strict=True)

+    if load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     new_linear_custom = new_linear_custom.cuda()
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)

+    if not load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()

+    x_third = x.clone().cuda().requires_grad_(True)
+    fx_third = new_linear_custom2(x_third).float()
+    (fx_third * grad_proj).mean().backward()
+
     # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
     if has_fp16_weights or not deserialize_before_cuda:
         assert torch.allclose(fx_first, fx_second, atol=1e-5)
         assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+
+    assert torch.allclose(fx_first, fx_third, atol=1e-5)
+    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
\ No newline at end of file
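The final hunk mirrors, for Linear8bitLt, the user-facing round trip that the new save_before_forward / load_before_cuda parameters exercise. A minimal sketch under the same assumptions as the test (int8 weights, fp16 activations; torch_save_to_buffer / torch_load_from_buffer are the helpers defined above):

    import torch
    import bitsandbytes as bnb

    # Frozen int8 weights; quantization happens on the .cuda() move.
    linear = bnb.nn.Linear8bitLt(32, 96, has_fp16_weights=False).cuda()

    buffer = torch_save_to_buffer(linear)      # whole-module save, CB/SCB included
    restored = torch_load_from_buffer(buffer)  # arrives ready to run on the GPU

    x = torch.randn(3, 32, dtype=torch.half, device="cuda")
    assert torch.allclose(linear(x), restored(x), atol=1e-5)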