Commit 7a117e44 authored by Ruslan Svirschevski

cleanup1

parent 6cf0f05d
@@ -593,6 +593,8 @@ class QuantState:
        # unpacking tensor with non-tensor components
        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
        if len(qs_key) == 1:
            qs_key = qs_key[0]
            assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
@@ -605,22 +607,22 @@ class QuantState:
            offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
            state2 = cls(
                absmax=qs_dict['nested_absmax'].to(device),
                code=qs_dict['nested_code'].to(device),
                blocksize=qs_dict['nested_blocksize'],
                code=qs_dict['nested_code'].to(device),
                dtype=getattr(torch, qs_dict['nested_dtype']),
            )
        else:
            offset, state2 = None, None
        quant_state = cls(
            quant_type=qs_dict['quant_type'],
            absmax=qs_dict['absmax'].to(device),
            shape=torch.Size(qs_dict['shape']),
            dtype=getattr(torch, qs_dict['dtype']),
            blocksize=qs_dict['blocksize'],
            code=qs_dict['code'].to(device),
            dtype=getattr(torch, qs_dict['dtype']),
            shape=torch.Size(qs_dict['shape']),
            offset=offset,
            state2=state2,
            quant_type=qs_dict['quant_type'],
            code=qs_dict['code'].to(device),
        )
        return quant_state
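
For context, a minimal sketch (not part of the diff above) of driving from_dict() with an unpacked stats dict, using only the keys the method reads here; the tensor values are placeholders and the bitsandbytes.functional import path is an assumption.

import torch
from bitsandbytes.functional import QuantState  # import path assumed

# Unpacked dict containing exactly the keys consumed by from_dict() above;
# the tensor values are illustrative placeholders, not real quantization stats.
qs_dict = {
    'quant_type': 'nf4',
    'absmax': torch.rand(16),
    'blocksize': 64,
    'code': torch.rand(16),
    'dtype': 'float16',    # resolved via getattr(torch, ...)
    'shape': (32, 32),     # wrapped into torch.Size(...)
}
state = QuantState.from_dict(qs_dict=qs_dict, device='cpu')
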
@@ -630,20 +632,20 @@ class QuantState:
        param: packed -- returns dict[str, torch.Tensor] for state_dict
        """
        qs_dict = {
            'quant_type': self.quant_type,
            'absmax': self.absmax,
            'blocksize': self.blocksize,
            'code': self.code,
            'shape': tuple(self.shape),
            'dtype': str(self.dtype).strip('torch.'),
            'blocksize': self.blocksize,
            'quant_type': self.quant_type,
            'shape': tuple(self.shape) if self.nested else None,
        }
        if self.nested:
            qs_dict.update({
                'nested_absmax': self.state2.absmax,
                'nested_code': self.state2.code,
                'nested_offset': self.offset.item(),
                'nested_blocksize': self.state2.blocksize,
                'nested_code': self.state2.code,
                'nested_dtype': str(self.state2.dtype).strip('torch.'),
                'nested_offset': self.offset.item(),
            })
        if not packed:
            return qs_dict
......
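
Continuing the sketch above: as_dict() should return the same plain key/value layout that from_dict() consumes, so the two methods round-trip. packed=False is passed explicitly because its default is not visible in this diff, and the check assumes the constructor simply stores the fields it receives.

qs_dict_again = state.as_dict(packed=False)   # 'state' from the previous sketch
assert qs_dict_again['quant_type'] == 'nf4'
assert qs_dict_again['blocksize'] == 64
assert qs_dict_again['dtype'] == 'float16'    # str(torch.float16).strip('torch.')
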
@@ -156,7 +156,8 @@ class Params4bit(torch.nn.Parameter):
    @classmethod
    def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs):
        if data is None:
            data = quantized_stats.pop('weight')
            weight_key = [k for k in quantized_stats if k.endswith(".weight")][0]
            data = quantized_stats.pop(weight_key)
        self = torch.Tensor._make_subclass(cls, data.to(device))
        self.requires_grad = requires_grad
        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
......
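
A hedged usage sketch for the change above: from_prequantized() now locates the packed weight tensor by its '.weight' suffix instead of a hard-coded 'weight' key. The import path, file name, and checkpoint layout below are assumptions.

import torch
from bitsandbytes.nn import Params4bit  # import path assumed

# 'quantized_stats' is assumed to hold one '<...>.weight' tensor (the packed 4-bit
# data) plus the quant-state entries that QuantState.from_dict() understands.
quantized_stats = torch.load('q_proj_quantized_stats.pt')  # hypothetical file
param = Params4bit.from_prequantized(quantized_stats, requires_grad=False, device='cuda')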