Commit 7a117e44 authored by Ruslan Svirschevski

cleanup1

parent 6cf0f05d
@@ -593,6 +593,8 @@ class QuantState:
         # unpacking tensor with non-tensor components
         qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
+        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
+            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
         if len(qs_key) == 1:
             qs_key = qs_key[0]
             assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
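The assertion added in this hunk distinguishes the two layouts `from_dict` accepts: a packed `qs_dict`, where the non-tensor components travel inside a single `*.quant_state.bitsandbytes__nf4` / `__fp4` tensor entry, and an unpacked `qs_dict`, where `quant_type`, `absmax`, etc. appear as top-level keys. A minimal sketch of that check in isolation (the example key names below are illustrative, not taken from this commit):

```python
import torch

def is_packed_qs_dict(qs_dict: dict) -> bool:
    """Mirror of the detection logic above: True for the packed layout, False for unpacked."""
    qs_key = [k for k, v in qs_dict.items()
              if "quant_state" in k and isinstance(v, torch.Tensor)]
    assert len(qs_key) == 1 or (not qs_key and 'quant_type' in qs_dict), \
        f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
    return len(qs_key) == 1

# hypothetical examples of the two layouts
packed = {'weight.quant_state.bitsandbytes__nf4': torch.zeros(10, dtype=torch.uint8)}
unpacked = {'quant_type': 'nf4', 'absmax': torch.zeros(4), 'blocksize': 64}
assert is_packed_qs_dict(packed) and not is_packed_qs_dict(unpacked)
```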
@@ -605,22 +607,22 @@ class QuantState:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
             state2 = cls(
                 absmax=qs_dict['nested_absmax'].to(device),
-                code=qs_dict['nested_code'].to(device),
                 blocksize=qs_dict['nested_blocksize'],
+                code=qs_dict['nested_code'].to(device),
                 dtype=getattr(torch, qs_dict['nested_dtype']),
             )
         else:
             offset, state2 = None, None
         quant_state = cls(
+            quant_type=qs_dict['quant_type'],
             absmax=qs_dict['absmax'].to(device),
-            shape=torch.Size(qs_dict['shape']),
-            dtype=getattr(torch, qs_dict['dtype']),
             blocksize=qs_dict['blocksize'],
+            code=qs_dict['code'].to(device),
+            dtype=getattr(torch, qs_dict['dtype']),
+            shape=torch.Size(qs_dict['shape']),
             offset=offset,
             state2=state2,
-            quant_type=qs_dict['quant_type'],
-            code=qs_dict['code'].to(device),
         )
         return quant_state
@@ -630,20 +632,20 @@ class QuantState:
         param: packed -- returns dict[str, torch.Tensor] for state_dict
         """
         qs_dict = {
-            'quant_type': self.quant_type,
             'absmax': self.absmax,
+            'blocksize': self.blocksize,
             'code': self.code,
-            'shape': tuple(self.shape),
             'dtype': str(self.dtype).strip('torch.'),
-            'blocksize': self.blocksize,
+            'shape': tuple(self.shape) if self.nested else None,
+            'quant_type': self.quant_type,
         }
         if self.nested:
             qs_dict.update({
                 'nested_absmax': self.state2.absmax,
-                'nested_code': self.state2.code,
-                'nested_offset': self.offset.item(),
                 'nested_blocksize': self.state2.blocksize,
+                'nested_code': self.state2.code,
                 'nested_dtype': str(self.state2.dtype).strip('torch.'),
+                'nested_offset': self.offset.item(),
             })
         if not packed:
             return qs_dict
...
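The two QuantState hunks keep serialization and deserialization symmetric: the dict written here feeds straight into `from_dict` above. Below is a minimal round-trip sketch, not part of the commit, assuming this branch's `quantize_4bit` already returns a `QuantState`, that `QuantState` is importable from `bitsandbytes.functional` as in the released library, and that the dict writer shown above is exposed as `as_dict(packed=False)` (its name is cut off in the hunk). `compress_statistics=True` keeps the state nested so the `'shape'` entry survives the `if self.nested else None` logic; a CUDA device is required.

```python
import torch
from bitsandbytes.functional import quantize_4bit, dequantize_4bit, QuantState

# quantize a dummy fp16 weight to NF4 with nested (double-quantized) statistics
w = torch.randn(128, 64, dtype=torch.float16, device='cuda')
w4, qstate = quantize_4bit(w, quant_type='nf4', compress_statistics=True)

# serialize the quantization state to plain tensors / scalars ...
qs_dict = qstate.as_dict(packed=False)  # assumed method name, see note above

# ... and rebuild an equivalent QuantState on the same device
qstate2 = QuantState.from_dict(qs_dict=qs_dict, device=torch.device('cuda'))

# both states should dequantize the packed payload to the same tensor
assert torch.allclose(dequantize_4bit(w4, qstate), dequantize_4bit(w4, qstate2))
```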
@@ -156,7 +156,8 @@ class Params4bit(torch.nn.Parameter):
     @classmethod
     def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs):
         if data is None:
-            data = quantized_stats.pop('weight')
+            weight_key = [k for k in quantized_stats if k.endswith(".weight")][0]
+            data = quantized_stats.pop(weight_key)
         self = torch.Tensor._make_subclass(cls, data.to(device))
         self.requires_grad = requires_grad
         self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
...
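The Params4bit change stops assuming the 4-bit payload sits under the bare key `'weight'`: in a module state_dict it arrives under a prefixed name such as `'linear.weight'`, so `from_prequantized` now pops the first key ending in `".weight"` and hands the remaining entries to `QuantState.from_dict`. A rough usage sketch with made-up key names, shapes, and values (`device='cpu'` only to keep the illustration self-contained; real use would pass a CUDA device):

```python
import torch
from bitsandbytes.nn import Params4bit

# illustrative, unpacked subset of what a 4-bit Linear layer's state_dict might carry
quantized_stats = {
    'linear.weight': torch.zeros(4096, 1, dtype=torch.uint8),  # packed 4-bit payload
    'quant_type': 'nf4',
    'absmax': torch.zeros(128),
    'blocksize': 64,
    'code': torch.zeros(16),
    'dtype': 'float16',
    'shape': (128, 64),
}

# data=None triggers the new lookup: 'linear.weight' is popped as the payload,
# and the remaining entries go through QuantState.from_dict
param = Params4bit.from_prequantized(quantized_stats, data=None, device='cpu')
print(type(param), param.quant_state.quant_type)
```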