Commit 6cf0f05d authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

rework of non-tensor qs items storage

parent 6a934d4f
......@@ -13,8 +13,9 @@ from scipy.stats import norm
import numpy as np
from functools import reduce # Required in Python 3
from typing import Tuple
from typing import Tuple, Any
from torch import Tensor
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from .cextension import COMPILED_WITH_CUDA, lib
......@@ -580,58 +581,77 @@ class QuantState:
self.nested = state2 is not None
@classmethod
def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
    """
    Unpack a dict of serialized components into a QuantState.

    Non-tensor components (ints, strings, shape) may arrive packed inside a
    single uint8 tensor whose key ends with
    `...weight.quant_state.bitsandbytes__[nf4/fp4]`; such an item is detected
    by key and expanded via `unpack_tensor_to_dict` before reconstruction.

    Parameters:
    - qs_dict: mapping of (possibly module-path-prefixed) component names to values.
    - device: target device for the tensor components.

    Returns:
    A QuantState; includes a nested `state2` QuantState when the serialized
    state used nested (double) quantization.
    """
    # Expand the packed non-tensor components, if present.
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    if len(qs_key) == 1:
        qs_key = qs_key[0]
        if 'bitsandbytes__nf4' not in qs_key and 'bitsandbytes__fp4' not in qs_key:
            # raise instead of assert: validation must survive `python -O`
            raise ValueError(f"invalid qs_key value {qs_key}")
        qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))

    # Strip module-path prefixes so lookups below use bare component names.
    qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}

    if 'nested_absmax' in qs_dict:
        # Nested (double) quantization: rebuild the inner state first.
        offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
        state2 = cls(
            absmax=qs_dict['nested_absmax'].to(device),
            code=qs_dict['nested_code'].to(device),
            blocksize=qs_dict['nested_blocksize'],
            dtype=getattr(torch, qs_dict['nested_dtype']),
        )
    else:
        offset, state2 = None, None

    quant_state = cls(
        absmax=qs_dict['absmax'].to(device),
        shape=torch.Size(qs_dict['shape']),
        dtype=getattr(torch, qs_dict['dtype']),
        blocksize=qs_dict['blocksize'],
        offset=offset,
        state2=state2,
        quant_type=qs_dict['quant_type'],
        code=qs_dict['code'].to(device),
    )
    return quant_state
def as_dict(self, packed=False):
    """
    Serialize this QuantState into a dict, for use in `_save_to_state_dict()`.

    Parameters:
    - packed: when True, return a state_dict-compatible dict[str, Tensor]
      in which all non-tensor components are JSON-packed into a single
      uint8 tensor keyed `quant_state.bitsandbytes__{quant_type}`.

    Returns:
    dict of tensors and plain Python values (or, when packed, tensors only).
    """
    # NOTE: `removeprefix`, not `strip` — `str.strip('torch.')` removes a
    # *character set* from both ends and can corrupt dtype names.
    qs_dict = {
        'absmax': self.absmax,
        'code': self.code,
        'shape': tuple(self.shape),
        'dtype': str(self.dtype).removeprefix('torch.'),
        'blocksize': self.blocksize,
        'quant_type': self.quant_type,
    }
    if self.nested:
        qs_dict.update({
            'nested_absmax': self.state2.absmax,
            'nested_code': self.state2.code,
            'nested_offset': self.offset.item(),
            'nested_blocksize': self.state2.blocksize,
            'nested_dtype': str(self.state2.dtype).removeprefix('torch.'),
        })
    if not packed:
        return qs_dict

    # Split tensor vs non-tensor items; pack the latter into one uint8 tensor.
    qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
    non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
    qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
    return qs_packed_dict
def to(self, device):
# make sure the quantization state is on the right device
......
......@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
data = quantized_stats.pop('weight')
self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(quant_state_dict=quantized_stats, device=device)
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
self.blocksize = self.quant_state.blocksize
self.compress_statistics = self.quant_state.nested
self.quant_type = self.quant_state.quant_type
......@@ -226,18 +226,14 @@ class Linear4bit(nn.Linear):
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """
    Save weight and bias (via the parent implementation), then fill
    `destination` with the components of the weight's quant_state.

    The quant_state is serialized with `as_dict(packed=True)`, so every
    entry is a tensor: the non-tensor components travel JSON-packed inside
    a single `quant_state.bitsandbytes__*` uint8 tensor.
    """
    super()._save_to_state_dict(destination, prefix, keep_vars)  # weight and bias
    if getattr(self.weight, "quant_state", None) is not None:
        for k, v in self.weight.quant_state.as_dict(packed=True).items():
            # mirror the parent's keep_vars contract: detach unless asked not to
            destination[prefix + "weight." + k] = v if keep_vars else v.detach()
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
......
import json
import shlex
import subprocess
import torch
......@@ -158,3 +159,36 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei
if func is not None: func(module)
return model
def pack_dict_to_tensor(source_dict):
    """
    Serialize a dictionary into a torch tensor, for storing quant_state
    items inside a state_dict.

    Parameters:
    - source_dict: the dictionary to be packed (must be JSON-serializable).

    Returns:
    A 1-D torch.uint8 tensor holding the UTF-8 encoded JSON of the dict.
    """
    payload = json.dumps(source_dict).encode('utf-8')
    return torch.tensor(list(payload), dtype=torch.uint8)
def unpack_tensor_to_dict(tensor_data):
    """
    Unpack a torch tensor (as produced by `pack_dict_to_tensor`) back into
    a Python dictionary.

    Parameters:
    - tensor_data: 1-D uint8 tensor holding UTF-8 encoded JSON.

    Returns:
    The decoded Python dictionary.
    """
    # .cpu() so unpacking also works when the packed tensor was loaded onto
    # an accelerator device — .numpy() requires host memory.
    json_bytes = bytes(tensor_data.cpu().numpy())
    json_str = json_bytes.decode('utf-8')
    return json.loads(json_str)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment