Commit 1d541b50 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski

[WiP] rework of Q_state save format

parent 48b3e770
@@ -585,33 +585,54 @@ class QuantState:
         unpacks dict of tensors into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
         """
-        tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
-        quant_state_dict = {k.split('.')[-1] :v for k, v in quant_state_dict.items()}
+        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
+        if 'quant_state_dict' in quant_state_dict:
+            quant_state_dict |= quant_state_dict.pop('quant_state_dict')
 
         if 'nested_absmax' in quant_state_dict:
-            offset = quant_state_dict['nested_offset']
+            offset = torch.tensor(float(quant_state_dict['nested_offset'])).to(device)
             state2 = cls(
                 absmax=quant_state_dict['nested_absmax'].to(device),
                 code=quant_state_dict['nested_code'].to(device),
-                blocksize=quant_state_dict['nested_blocksize'].item(),
-                dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
+                blocksize=int(quant_state_dict['nested_blocksize']),
+                dtype=getattr(torch, quant_state_dict['nested_dtype']),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
             absmax=quant_state_dict['absmax'].to(device),
-            shape=torch.Size(quant_state_dict['shape']),
-            dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
-            blocksize=quant_state_dict['blocksize'].item(),
+            shape=torch.Size(map(int, quant_state_dict['shape'].split(','))),
+            dtype=getattr(torch, quant_state_dict['dtype']),
+            blocksize=int(quant_state_dict['blocksize']),
             offset=offset,
             state2=state2,
-            quant_type=tensor2str(quant_state_dict['quant_type']),
+            quant_type=quant_state_dict['quant_type'],
             code=quant_state_dict['code'].to(device),
         )
         return quant_state
 
+    def as_dict(self):
+        """dict of tensors and strings to use in serialization via _save_to_state_dict()"""
+        qs_dict = {
+            'absmax': self.absmax,
+            'code': self.code,
+            'shape': ','.join(map(str, self.shape)),
+            'dtype': str(self.dtype).replace('torch.', ''),
+            'blocksize': str(self.blocksize),
+            'quant_type': self.quant_type,
+        }
+        if self.nested:
+            qs_dict.update({
+                'nested_absmax': self.state2.absmax,
+                'nested_code': self.state2.code,
+                'nested_offset': f"{self.offset.item()}",
+                'nested_blocksize': str(self.state2.blocksize),
+                'nested_dtype': str(self.state2.dtype).replace('torch.', ''),
+            })
+        return qs_dict
+
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
...
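For orientation, a minimal round-trip sketch of the new dict format with a toy, non-nested state. The import path, the name `from_dict` for the unpacking classmethod, and the assumption that omitted constructor arguments default to None are guesses based on this hunk, not confirmed by the diff.

# Sketch only: exercise as_dict() and the unpacking classmethod on a toy state.
# Assumes QuantState is importable from bitsandbytes.functional and that the
# classmethod shown above is named `from_dict`; both are assumptions.
import torch
from bitsandbytes.functional import QuantState

qs = QuantState(
    absmax=torch.rand(16),              # per-block scales
    shape=torch.Size([32, 32]),         # original weight shape
    dtype=torch.float16,                # original weight dtype
    blocksize=64,
    quant_type='nf4',
    code=torch.linspace(-1, 1, 16),     # quantization lookup table
)

packed = qs.as_dict()
# tensors stay tensors, everything else becomes a string:
#   packed['shape'] == '32,32', packed['dtype'] == 'float16', packed['blocksize'] == '64'

qs2 = QuantState.from_dict(packed, device='cpu')
assert qs2.shape == qs.shape and qs2.dtype == qs.dtype and qs2.blocksize == qs.blocksize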
@@ -224,36 +224,18 @@ class Linear4bit(nn.Linear):
             warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
             warnings.filterwarnings('ignore', message='.*inference or training')
 
-    def _update_buffers(self):
-        def string_to_tensor(s):
-            """stores string as ints for serialization. assumes codes fit int16"""
-            return torch.tensor([ord(x) for x in s], dtype=torch.int16)
-
-        if getattr(self.weight, 'quant_state', None) is not None:
-            weight_quant_state = self.weight.quant_state
-            self.register_buffer('absmax', weight_quant_state.absmax)
-            self.register_buffer('shape', torch.tensor(weight_quant_state.shape))
-            self.register_buffer('dtype', string_to_tensor(str(weight_quant_state.dtype).strip('torch')))
-            self.register_buffer('blocksize', torch.tensor(weight_quant_state.blocksize))
-            self.register_buffer('quant_type', string_to_tensor(weight_quant_state.quant_type))
-            self.register_buffer('code', weight_quant_state.code)
-            if weight_quant_state.nested:
-                self.register_buffer('nested_offset', weight_quant_state.offset)
-                self.register_buffer('nested_absmax', weight_quant_state.state2.absmax)
-                self.register_buffer('nested_code', weight_quant_state.state2.code)
-                self.register_buffer('nested_blocksize', torch.tensor(weight_quant_state.state2.blocksize))
-                self.register_buffer('nested_dtype', string_to_tensor(str(weight_quant_state.state2.dtype).strip('torch')))
-
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
-        fill state_dict with components of nf4 besides weight and bias,
-        TODO: test with other 4-bit Q-types
+        fill state_dict with components of quant_state
         """
-        self._update_buffers()  # link the quant_state items with _buffers
+        if getattr(self.weight, "quant_state", None) is not None:
+            quant_state_dict = self.weight.quant_state.as_dict()
+            tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)]
+            for k in tensor_keys:
+                destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach()
+            destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict
+            destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"]
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
     def forward(self, x: torch.Tensor):
...
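A hedged usage sketch of what the reworked _save_to_state_dict is intended to emit for a single layer. It assumes a CUDA device (moving the module is what quantizes the weight and attaches weight.quant_state); the layer sizes are illustrative.

# Sketch only: inspect the state_dict keys produced on this branch.
# Assumes CUDA is available so that .cuda() quantizes the weight and
# creates weight.quant_state.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(128, 128, bias=False, quant_type="nf4",
                          compute_dtype=torch.float16)
layer = layer.cuda()          # quantization happens here, creating quant_state

sd = layer.state_dict()
# Per the hunk above, besides the packed 'weight' itself we expect one entry per
# tensor in quant_state (weight.absmax, weight.code, and the nested_* tensors
# when compress_statistics is enabled), plus the string-valued metadata grouped
# under 'weight.quant_state_dict' and a 'weight.quantization_method' marker.
print(sorted(sd.keys()))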