Commit 781fcd5b authored by Ruslan Svirschevski's avatar Ruslan Svirschevski

partially reverted 76b40a5c

parent c6d0a847
@@ -571,9 +571,9 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n

class QuantState:
    """container for quantization state components to work with Params4bit and similar clases"""
    valid_quant_types = ('fp4', 'nf4')
-   valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
-   valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
-                    'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+   valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
+   valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
+                    'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
        self.absmax = absmax
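For orientation, a minimal sketch (not part of this commit) of what the two versions of `valid_qs_type_keys` evaluate to; the change drops the `quant_state.` prefix from the expected key endings:

    valid_quant_types = ('fp4', 'nf4')

    # before: ['quant_state.bitsandbytes__fp4', 'quant_state.bitsandbytes__nf4']
    # after:  ['bitsandbytes__fp4', 'bitsandbytes__nf4']
    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
    print(valid_qs_type_keys)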
@@ -611,16 +611,19 @@ class QuantState:
        """
        # unpacking tensor with non-tensor components
-       qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+       qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
        if not len(qs_key) and 'quant_type' not in qs_dict:
            raise ValueError("Expected packed or unpacked quant_state items, found neither")
-       elif len(qs_key) != 1:
-           raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+       elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+           raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")

        # unpacking minor and non-tensor quant state items if necessary
        if len(qs_key) == 1:
            qs_key = qs_key[0]
-           qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+           qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+
+       qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+       assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

        if 'nested_absmax' in qs_dict:
            offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
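A minimal sketch of how the revised `from_dict` locates the packed item and strips prefixes; the key names below are illustrative assumptions, and the packed tensor would really be expanded via `unpack_tensor_to_dict`:

    import torch

    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in ('fp4', 'nf4')]

    # hypothetical serialized dict: every key carries a module prefix
    qs_dict = {
        "weight.absmax": torch.zeros(4),
        "weight.quant_map": torch.zeros(16),
        "weight.quant_state.bitsandbytes__nf4": torch.zeros(8, dtype=torch.uint8),
    }

    # candidates: any tensor-valued key containing "quant_state" ...
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    # ... whose last dot-separated component is a known type ending
    assert len(qs_key) == 1 and qs_key[0].split(".")[-1] in valid_qs_type_keys

    packed = qs_dict.pop(qs_key[0])  # real code: qs_dict.update(unpack_tensor_to_dict(packed))
    qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
    print(sorted(qs_dict))  # ['absmax', 'quant_map']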
@@ -654,7 +657,7 @@ class QuantState:
            'quant_type': self.quant_type,
            'absmax': self.absmax,
            'blocksize': self.blocksize,
            'quant_map': self.code,
            'dtype': str(self.dtype).strip('torch.'),
            'shape': tuple(self.shape) if self.nested else None,
        }
@@ -677,6 +680,7 @@ class QuantState:
    def to(self, device):
        # make sure the quantization state is on the right device
        self.absmax = self.absmax.to(device)
+       self.offset = self.offset.to(device)
        if self.nested:
            self.offset = self.offset.to(device)
            self.state2.absmax = self.state2.absmax.to(device)
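The added line moves `offset` before the `nested` guard while the guarded move stays in place; since `offset` defaults to `None` in `__init__`, a defensive variant (a sketch, not what this commit does) would guard it explicitly:

    def to(self, device):
        # move tensor-valued quantization state to the target device
        self.absmax = self.absmax.to(device)
        if self.offset is not None:  # offset defaults to None for non-nested states
            self.offset = self.offset.to(device)
        if self.nested:
            self.state2.absmax = self.state2.absmax.to(device)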
...
@@ -155,28 +155,38 @@ class Params4bit(torch.nn.Parameter):
        return self

    @classmethod
-   def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-       data = state_dict.pop(prefix.rstrip('.'))
-
-       # extracting components for QuantState from state_dict
-       qs_dict = {}
-       for k, v in state_dict.items():
-           if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-               qs_dict[k] = v
-       state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-       qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-       if data.device.type != "cuda":
-           raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-
-       cls.requires_grad = requires_grad
-       cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-       cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
-       cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
-       cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
-
-       self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-       return self, state_dict
+   def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
+       self = torch.Tensor._make_subclass(cls, data.to(device))
+       self.requires_grad = requires_grad
+       self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+       self.blocksize = self.quant_state.blocksize
+       self.compress_statistics = self.quant_state.nested
+       self.quant_type = self.quant_state.quant_type
+       return self
+
+   # @classmethod
+   # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+   #     data = state_dict.pop(prefix.rstrip('.'))
+
+   #     # extracting components for QuantState from state_dict
+   #     qs_dict = {}
+   #     for k, v in state_dict.items():
+   #         if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+   #             qs_dict[k] = v
+   #     state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+   #     qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}

+   #     if data.device.type != "cuda":
+   #         raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")

+   #     cls.requires_grad = requires_grad
+   #     cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+   #     cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
+   #     cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
+   #     cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
+   #     self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
+   #     return self, state_dict

    def cuda(self, device):
        w = self.data.contiguous().half().cuda(device)
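A usage sketch for the new constructor, mirroring the updated test further down; the layer sizes and `quant_type` here are illustrative assumptions:

    import torch
    import bitsandbytes as bnb

    # assumed setup: a 4-bit layer, quantized by moving it to GPU
    linear_q = bnb.nn.Linear4bit(64, 64, bias=True, quant_type="nf4").cuda()

    sd = linear_q.state_dict()          # quant-state items live under "weight.*"
    bias_data = sd.pop("bias", None)
    weight_data = sd.pop("weight")

    # rebuild the Params4bit from the packed weight plus the remaining stats
    weight2 = bnb.nn.Params4bit.from_prequantized(
        data=weight_data, quantized_stats=sd, requires_grad=False, device="cuda",
    )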
@@ -251,17 +261,17 @@ class Linear4bit(nn.Linear):
        for k, v in self.weight.quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-   def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                             missing_keys, unexpected_keys, error_msgs):
-       # Note: super()._load_from_state_dict() is not called here intentionally.
-       if self.bias is not None:
-           bias_data = state_dict.pop(prefix + "bias", None)
-           self.bias.data = bias_data.to(self.bias.data.device)
-
-       self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-           state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-       )
-       unexpected_keys.extend(state_dict.keys())
+   # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+   #                           missing_keys, unexpected_keys, error_msgs):
+   #     # Note: super()._load_from_state_dict() is not called here intentionally.
+   #     if self.bias is not None:
+   #         bias_data = state_dict.pop(prefix + "bias", None)
+   #         self.bias.data = bias_data.to(self.bias.data.device)

+   #     self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+   #         state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+   #     )
+   #     unexpected_keys.extend(state_dict.keys())

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
...
@@ -7,8 +7,6 @@ import pytest
import torch

import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
    # saving to state_dict:
    sd = linear_q.state_dict()

+   # restoring from state_dict:
+   bias_data2 = sd.pop("bias", None)
+   weight_data2 = sd.pop("weight")
+   weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
        device=device,  # TODO create on meta device to save loading time
    )

    # loading weights from state_dict:
-   linear_q2.load_state_dict(sd)
+   linear_q2.weight = weight2.to(device)
+   if bias:
+       linear_q2.bias = torch.nn.Parameter(bias_data2)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight
@@ -61,7 +64,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
...