Commit 1d541b50 authored by Ruslan Svirschevski

[WiP] rework of Q_state save format

parent 48b3e770
@@ -585,33 +585,54 @@ class QuantState:
         unpacks dict of tensors into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
         """
-        tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
-        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
+        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
+        if 'quant_state_dict' in quant_state_dict:  # unpack the nested dict of string components
+            quant_state_dict |= quant_state_dict.pop('quant_state_dict')
 
         if 'nested_absmax' in quant_state_dict:
-            offset = quant_state_dict['nested_offset']
+            offset = torch.tensor(float(quant_state_dict['nested_offset'])).to(device)
             state2 = cls(
                 absmax=quant_state_dict['nested_absmax'].to(device),
                 code=quant_state_dict['nested_code'].to(device),
-                blocksize=quant_state_dict['nested_blocksize'].item(),
-                dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
+                blocksize=int(quant_state_dict['nested_blocksize']),
+                dtype=getattr(torch, quant_state_dict['nested_dtype']),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
             absmax=quant_state_dict['absmax'].to(device),
-            shape=torch.Size(quant_state_dict['shape']),
-            dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
-            blocksize=quant_state_dict['blocksize'].item(),
+            shape=torch.Size(map(int, quant_state_dict['shape'].split(','))),
+            dtype=getattr(torch, quant_state_dict['dtype']),
+            blocksize=int(quant_state_dict['blocksize']),
             offset=offset,
             state2=state2,
-            quant_type=tensor2str(quant_state_dict['quant_type']),
+            quant_type=quant_state_dict['quant_type'],
             code=quant_state_dict['code'].to(device),
         )
         return quant_state
 
+    def as_dict(self):
+        """dict of tensors and strings to use in serialization via _save_to_state_dict()"""
+        qs_dict = {
+            'absmax': self.absmax,
+            'code': self.code,
+            'shape': ','.join(map(str, self.shape)),
+            'dtype': str(self.dtype).split('.')[-1],  # 'torch.float16' -> 'float16'
+            'blocksize': str(self.blocksize),
+            'quant_type': self.quant_type,
+        }
+        if self.nested:
+            qs_dict.update({
+                'nested_absmax': self.state2.absmax,
+                'nested_code': self.state2.code,
+                'nested_offset': f"{self.offset.item()}",
+                'nested_blocksize': str(self.state2.blocksize),
+                'nested_dtype': str(self.state2.dtype).split('.')[-1],
+            })
+        return qs_dict
 
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
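The two serialization methods above are designed as inverses: as_dict() flattens a QuantState into tensors plus plain strings, and the classmethod at the top of the hunk rebuilds the state from such a dict. The hunk starts mid-method, so the name and signature used below (QuantState.from_dict(quant_state_dict, device)) are assumptions; a minimal round-trip sketch under those assumptions:

    import torch
    from bitsandbytes.functional import QuantState  # import path assumed

    # `quant_state` stands for an existing QuantState, e.g. taken from a
    # quantized Linear4bit weight: quant_state = layer.weight.quant_state
    qs_dict = quant_state.as_dict()
    # tensors stay tensors, everything else becomes a plain string, e.g.:
    # {'absmax': tensor([...]), 'code': tensor([...]), 'shape': '4096,4096',
    #  'dtype': 'float16', 'blocksize': '64', 'quant_type': 'nf4'}

    # rebuild on the target device; method name and signature assumed
    restored = QuantState.from_dict(qs_dict, device=torch.device('cuda'))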
@@ -224,36 +224,18 @@ class Linear4bit(nn.Linear):
             warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
             warnings.filterwarnings('ignore', message='.*inference or training')
 
-    def _update_buffers(self):
-        def string_to_tensor(s):
-            """stores string as ints for serialization. assumes codes fit int16"""
-            return torch.tensor([ord(x) for x in s], dtype=torch.int16)
-
-        if getattr(self.weight, 'quant_state', None) is not None:
-            weight_quant_state = self.weight.quant_state
-            self.register_buffer('absmax', weight_quant_state.absmax)
-            self.register_buffer('shape', torch.tensor(weight_quant_state.shape))
-            self.register_buffer('dtype', string_to_tensor(str(weight_quant_state.dtype).strip('torch')))
-            self.register_buffer('blocksize', torch.tensor(weight_quant_state.blocksize))
-            self.register_buffer('quant_type', string_to_tensor(weight_quant_state.quant_type))
-            self.register_buffer('code', weight_quant_state.code)
-            if weight_quant_state.nested:
-                self.register_buffer('nested_offset', weight_quant_state.offset)
-                self.register_buffer('nested_absmax', weight_quant_state.state2.absmax)
-                self.register_buffer('nested_code', weight_quant_state.state2.code)
-                self.register_buffer('nested_blocksize', torch.tensor(weight_quant_state.state2.blocksize))
-                self.register_buffer('nested_dtype', string_to_tensor(str(weight_quant_state.state2.dtype).strip('torch')))
-
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
-        fill state_dict with components of nf4
-        TODO: test with other 4-bit Q-types
+        besides weight and bias,
+        fill state_dict with components of quant_state
         """
-        self._update_buffers()  # link the quant_state items with _buffers
+        if getattr(self.weight, "quant_state", None) is not None:
+            quant_state_dict = self.weight.quant_state.as_dict()
+            tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)]
+            for k in tensor_keys:
+                destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach()
+            destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict
+            destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"]
+
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
     def forward(self, x: torch.Tensor):
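For context, a sketch of the checkpoint layout the reworked _save_to_state_dict produces. This is an illustration, not part of the commit; it assumes the WiP code above is installed, a CUDA build of bitsandbytes, and a freshly quantized Linear4bit:

    import torch
    import bitsandbytes as bnb

    # moving the layer to CUDA quantizes the weight and attaches quant_state
    layer = bnb.nn.Linear4bit(
        64, 64, bias=True, compress_statistics=False, quant_type='nf4',
    ).cuda()

    destination = {}
    layer._save_to_state_dict(destination, prefix='layer.', keep_vars=False)
    print(sorted(destination))
    # tensor components land as flat keys, the remaining strings travel
    # together in one sub-dict, plus a quantization-method marker:
    # ['layer.bias', 'layer.weight', 'layer.weight.absmax', 'layer.weight.code',
    #  'layer.weight.quant_state_dict', 'layer.weight.quantization_method']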