Commit 48b3e770 authored by Ruslan Svirschevski

some renaming

parent 5bcc1ddc
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
 class QuantState:
+    """container for quantization state components to work with Params4bit and similar classes"""
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
@@ -579,32 +580,35 @@ class QuantState:
         self.nested = state2 is not None
     @classmethod
-    def from_kwargs(cls, kwargs, device):
+    def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.device) -> 'QuantState':
         """
         unpacks dict of tensors into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
         """
         tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
-        kwargs = {k.split('.')[-1] :v for k, v in kwargs.items()}
+        quant_state_dict = {k.split('.')[-1] :v for k, v in quant_state_dict.items()}
-        if 'nested_absmax' in kwargs:
-            offset = kwargs['nested_offset']
+        if 'nested_absmax' in quant_state_dict:
+            offset = quant_state_dict['nested_offset']
             state2 = cls(
-                absmax=kwargs['nested_absmax'].to(device),
-                code=kwargs['nested_code'].to(device),
-                blocksize=kwargs['nested_blocksize'].item(),
-                dtype=getattr(torch, tensor2str(kwargs['nested_dtype'])),
+                absmax=quant_state_dict['nested_absmax'].to(device),
+                code=quant_state_dict['nested_code'].to(device),
+                blocksize=quant_state_dict['nested_blocksize'].item(),
+                dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
             )
         else:
             offset, state2 = None, None
         quant_state = cls(
-            absmax=kwargs['absmax'].to(device),
-            shape=torch.Size(kwargs['shape']),
-            dtype=getattr(torch, tensor2str(kwargs['dtype'])),
-            blocksize=kwargs['blocksize'].item(),
+            absmax=quant_state_dict['absmax'].to(device),
+            shape=torch.Size(quant_state_dict['shape']),
+            dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
+            blocksize=quant_state_dict['blocksize'].item(),
             offset=offset,
             state2=state2,
-            quant_type=tensor2str(kwargs['quant_type']),
-            code=kwargs['code'].to(device),
+            quant_type=tensor2str(quant_state_dict['quant_type']),
+            code=quant_state_dict['code'].to(device),
         )
         return quant_state
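
For context, here is a minimal sketch of calling the renamed classmethod. It assumes QuantState as patched above is importable; the helper str2tensor and all tensor values are illustrative placeholders, not part of the commit, and the key names mirror the ones read by from_dict.

import torch

# Hypothetical inverse of the tensor2str lambda inside from_dict: encode a string
# as a tensor of character codes so it can travel inside a tensors-only state dict.
str2tensor = lambda s: torch.tensor([ord(c) for c in s], dtype=torch.uint8)

example_state = {
    'absmax': torch.rand(4),                  # per-block absolute maxima (placeholder values)
    'shape': torch.tensor([16, 16]),          # original weight shape, wrapped in torch.Size() by from_dict
    'dtype': str2tensor('float16'),           # decoded via tensor2str, then getattr(torch, ...)
    'blocksize': torch.tensor(64),            # read back with .item()
    'quant_type': str2tensor('nf4'),          # decoded via tensor2str
    'code': torch.linspace(-1.0, 1.0, 16),    # placeholder codebook, not a real quantization table
}

quant_state = QuantState.from_dict(quant_state_dict=example_state, device=torch.device('cpu'))
print(quant_state.shape, quant_state.dtype, quant_state.blocksize, quant_state.quant_type)
# Prefixed keys such as 'weight.absmax' are also accepted: from_dict keeps only
# the part after the last '.' via k.split('.')[-1].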
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
         data = quantized_stats.pop('weight')
         self = torch.Tensor._make_subclass(cls, data.to(device))
         self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_kwargs(kwargs=quantized_stats, device=device)
+        self.quant_state = QuantState.from_dict(quant_state_dict=quantized_stats, device=device)
         self.blocksize = self.quant_state.blocksize
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
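
The compress_statistics flag above corresponds to the branch of from_dict that builds a second QuantState for the quantized absmax values. A sketch of that case, again with placeholder tensors and the hypothetical str2tensor helper from the previous example:

import torch

str2tensor = lambda s: torch.tensor([ord(c) for c in s], dtype=torch.uint8)  # hypothetical helper, see above

nested_state = {
    'absmax': torch.randint(0, 256, (4,), dtype=torch.uint8),  # absmax stored quantized (placeholder)
    'shape': torch.tensor([16, 16]),
    'dtype': str2tensor('float16'),
    'blocksize': torch.tensor(64),
    'quant_type': str2tensor('nf4'),
    'code': torch.linspace(-1.0, 1.0, 16),                     # placeholder codebook
    # keys that trigger the nested branch in from_dict:
    'nested_absmax': torch.rand(1),
    'nested_code': torch.linspace(-1.0, 1.0, 256),             # placeholder codebook for the absmax values
    'nested_blocksize': torch.tensor(256),
    'nested_dtype': str2tensor('float32'),
    'nested_offset': torch.tensor(0.03),                       # passed through as the offset argument
}

quant_state = QuantState.from_dict(quant_state_dict=nested_state, device=torch.device('cpu'))
assert quant_state.nested  # True: from_dict built a second QuantState (state2) for the nested statistics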