Commit 5bcc1ddc authored by Ruslan Svirschevski

save/load 4bit squashed

parent 61a4a20d
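This commit teaches Linear4bit / Params4bit to serialize their quantization state through state_dict and to rebuild it with the new Params4bit.from_prequantized classmethod. As a minimal sketch of the intended round trip, mirroring the test added at the end of this diff (device, layer shape and the file name are illustrative only):

import torch
import bitsandbytes as bnb

# quantize an fp16 layer into a 4-bit Linear4bit (sketch)
linear = torch.nn.Linear(300, 400, dtype=torch.float16)
linear_q = bnb.nn.Linear4bit(300, 400, quant_type='nf4', device='cuda')
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False).to('cuda')
linear_q.bias.data = linear.bias.data.to('cuda')

# save: _save_to_state_dict() serializes the quant_state buffers alongside weight and bias
torch.save(linear_q.state_dict(), 'layer_4bit.pth')

# load: rebuild the 4-bit weight from the saved buffers
sd = torch.load('layer_4bit.pth')
bias = sd.pop('bias', None)
weight = sd.pop('weight')
restored = bnb.nn.Linear4bit(300, 400, quant_type='nf4', device='cuda')
restored.weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight)
if bias is not None:
    restored.bias.data = bias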
@@ -578,6 +578,36 @@ class QuantState:
        self.state2 = state2
        self.nested = state2 is not None

    @classmethod
    def from_kwargs(cls, kwargs, device):
        tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
        kwargs = {k.split('.')[-1]: v for k, v in kwargs.items()}

        if 'nested_absmax' in kwargs:
            offset = kwargs['nested_offset']
            state2 = cls(
                absmax=kwargs['nested_absmax'].to(device),
                code=kwargs['nested_code'].to(device),
                blocksize=kwargs['nested_blocksize'].item(),
                dtype=getattr(torch, tensor2str(kwargs['nested_dtype'])),
            )
        else:
            offset, state2 = None, None

        quant_state = cls(
            absmax=kwargs['absmax'].to(device),
            shape=torch.Size(kwargs['shape']),
            dtype=getattr(torch, tensor2str(kwargs['dtype'])),
            blocksize=kwargs['blocksize'].item(),
            offset=offset,
            state2=state2,
            quant_type=tensor2str(kwargs['quant_type']),
            code=kwargs['code'].to(device),
        )
        return quant_state

    def to(self, device):
        # make sure the quantization state is on the right device
        self.absmax = self.absmax.to(device)
......
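The tensor2str helper above is the decoding half of the serialization scheme: string-valued fields (the dtype name and quant_type) are stored as int16 tensors of character codes by the string_to_tensor helper added to Linear4bit._update_buffers further down in this diff, so they survive a buffer-only state_dict round trip. A small self-contained sketch of that encode/decode pair, using only the two helpers shown in this commit:

import torch

def string_to_tensor(s):
    # encoder from Linear4bit._update_buffers: one int16 per character
    return torch.tensor([ord(x) for x in s], dtype=torch.int16)

tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')

encoded = string_to_tensor(str(torch.float16).strip('torch'))  # char codes for ".float16"
assert tensor2str(encoded) == 'float16'                         # leading '.' is stripped on decode
assert getattr(torch, tensor2str(encoded)) is torch.float16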
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.functional import QuantState
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@@ -140,6 +140,7 @@ class Embedding(torch.nn.Embedding):
        return emb


class Params4bit(torch.nn.Parameter):

    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
        if data is None:
            data = torch.empty(0)
@@ -151,6 +152,18 @@ class Params4bit(torch.nn.Parameter):
        self.quant_state = quant_state
        self.data = data
        return self

    @classmethod
    def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs):
        if data is None:
            data = quantized_stats.pop('weight')

        self = torch.Tensor._make_subclass(cls, data.to(device))
        self.requires_grad = requires_grad
        self.quant_state = QuantState.from_kwargs(kwargs=quantized_stats, device=device)
        self.blocksize = self.quant_state.blocksize
        self.compress_statistics = self.quant_state.nested
        self.quant_type = self.quant_state.quant_type
        return self

    def cuda(self, device):
        w = self.data.contiguous().half().cuda(device)
@@ -211,6 +224,38 @@ class Linear4bit(nn.Linear):
            warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
            warnings.filterwarnings('ignore', message='.*inference or training')

    def _update_buffers(self):

        def string_to_tensor(s):
            """stores string as ints for serialization. assumes codes fit int16"""
            return torch.tensor([ord(x) for x in s], dtype=torch.int16)

        if getattr(self.weight, 'quant_state', None) is not None:
            weight_quant_state = self.weight.quant_state
            self.register_buffer('absmax', weight_quant_state.absmax)
            self.register_buffer('shape', torch.tensor(weight_quant_state.shape))
            self.register_buffer('dtype', string_to_tensor(str(weight_quant_state.dtype).strip('torch')))
            self.register_buffer('blocksize', torch.tensor(weight_quant_state.blocksize))
            self.register_buffer('quant_type', string_to_tensor(weight_quant_state.quant_type))
            self.register_buffer('code', weight_quant_state.code)
            if weight_quant_state.nested:
                self.register_buffer('nested_offset', weight_quant_state.offset)
                self.register_buffer('nested_absmax', weight_quant_state.state2.absmax)
                self.register_buffer('nested_code', weight_quant_state.state2.code)
                self.register_buffer('nested_blocksize', torch.tensor(weight_quant_state.state2.blocksize))
                self.register_buffer('nested_dtype', string_to_tensor(str(weight_quant_state.state2.dtype).strip('torch')))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """
        fill state_dict with components of nf4
        TODO: test with other 4-bit Q-types
        """
        self._update_buffers()  # link the quant_state items with _buffers
        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
......
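Taken together, _update_buffers and _save_to_state_dict make the state_dict of a quantized layer carry the weight, the optional bias, and one buffer per quant_state field. For an unprefixed layer with compress_statistics=True the keys come out roughly as below (the nested_* entries are absent when statistics are not compressed); QuantState.from_kwargs above strips any module prefix with k.split('.')[-1] before reading them back:

sd = linear_q.state_dict()   # linear_q: a quantized Linear4bit, as built in the test below
sorted(sd.keys())
# ['absmax', 'bias', 'blocksize', 'code', 'dtype',
#  'nested_absmax', 'nested_blocksize', 'nested_code', 'nested_dtype', 'nested_offset',
#  'quant_type', 'shape', 'weight']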
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.nn.modules import Linear4bit
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize(
    "quant_type, compress_statistics, bias",
    list(product(["nf4", "fp4"], [False, True], [False, True])),
)
def test_linear4_state_dict(quant_type, compress_statistics, bias):
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device=device,
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
    linear_q.weight = new_weight.to(device)
    if bias:
        linear_q.bias.data = linear.bias.data.to(device)
    # restoring from state_dict:
    sd = linear_q.state_dict()
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device=device,
    )
    linear_q2.weight = weight2.to(device)
    if bias:
        linear_q2.bias.data = bias_data2

    # matching
    a, b = linear_q.weight, linear_q2.weight
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

    # Forward test
    x = torch.rand(42, linear_q.shape[-1], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(state_path_4bit)
        size_ratio = size_4 / size_orig
        target_compression = 0.143 if original_dtype == torch.float32 else 0.285
        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        assert size_ratio < target_compression, ratio_error_msg
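The target_compression constants encode simple size arithmetic rather than a measured value: the packed 4-bit payload is a quarter of an fp16 checkpoint, and the absmax/code/shape buffers plus the bias add a few percent of overhead, hence the 0.285 ceiling against an fp16 original; an fp32 original is twice as large, so the ceiling halves to 0.143. A back-of-the-envelope check of that relationship (a sketch, not an exact on-disk size model):

packed_vs_fp16 = 4 / 16        # 0.25: 4-bit payload relative to fp16 weights
target_fp16 = 0.285            # 0.25 plus headroom for absmax, code table, shape/dtype buffers, bias
target_fp32 = 0.143            # an fp32 original doubles the denominator, so the ceiling halves
assert abs(target_fp32 - target_fp16 / 2) < 0.002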