"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "03df281275ad3fcb732a41ab1638c2e89afddb25"
Commit 48b3e770 authored by Ruslan Svirschevski

some renaming

parent 5bcc1ddc
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
 
 class QuantState:
+    """container for quantization state components to work with Params4bit and similar classes"""
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
@@ -579,32 +580,35 @@ class QuantState:
         self.nested = state2 is not None
 
     @classmethod
-    def from_kwargs(cls, kwargs, device):
+    def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.device) -> 'QuantState':
+        """
+        unpacks dict of tensors into QuantState;
+        where necessary, converts values into strings, torch.dtype, ints, etc.
+        """
         tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
-        kwargs = {k.split('.')[-1]: v for k, v in kwargs.items()}
+        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
 
-        if 'nested_absmax' in kwargs:
-            offset = kwargs['nested_offset']
+        if 'nested_absmax' in quant_state_dict:
+            offset = quant_state_dict['nested_offset']
             state2 = cls(
-                absmax=kwargs['nested_absmax'].to(device),
-                code=kwargs['nested_code'].to(device),
-                blocksize=kwargs['nested_blocksize'].item(),
-                dtype=getattr(torch, tensor2str(kwargs['nested_dtype'])),
+                absmax=quant_state_dict['nested_absmax'].to(device),
+                code=quant_state_dict['nested_code'].to(device),
+                blocksize=quant_state_dict['nested_blocksize'].item(),
+                dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
-            absmax=kwargs['absmax'].to(device),
-            shape=torch.Size(kwargs['shape']),
-            dtype=getattr(torch, tensor2str(kwargs['dtype'])),
-            blocksize=kwargs['blocksize'].item(),
+            absmax=quant_state_dict['absmax'].to(device),
+            shape=torch.Size(quant_state_dict['shape']),
+            dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
+            blocksize=quant_state_dict['blocksize'].item(),
             offset=offset,
             state2=state2,
-            quant_type=tensor2str(kwargs['quant_type']),
-            code=kwargs['code'].to(device),
+            quant_type=tensor2str(quant_state_dict['quant_type']),
+            code=quant_state_dict['code'].to(device),
         )
         return quant_state
...
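For reference, a minimal sketch of how the new from_dict entry point would be consumed, assuming QuantState from the patched module is importable. The str2tensor helper, the 'weight.'-prefixed key names, and the concrete values are illustrative assumptions inferred from the k.split('.')[-1] prefix stripping and the tensor2str decoding in the diff, not the library's actual serialization code:

import torch
from bitsandbytes.functional import QuantState  # assumed import path for the module patched above

def str2tensor(s: str) -> torch.Tensor:
    # hypothetical inverse of the tensor2str lambda above: store character
    # codes padded with '.', which tensor2str later strips off
    return torch.tensor([ord(c) for c in s.ljust(8, '.')], dtype=torch.uint8)

# hypothetical flat state dict for one non-nested 4-bit weight, as consumed
# by from_dict after the 'weight' entry itself has been popped off
quantized_stats = {
    'weight.absmax': torch.rand(64),
    'weight.shape': [128, 32],                  # passed to torch.Size(...) inside from_dict
    'weight.dtype': str2tensor('float16'),      # decoded via tensor2str + getattr(torch, ...)
    'weight.blocksize': torch.tensor(64),       # .item() turns this back into an int
    'weight.quant_type': str2tensor('nf4'),
    'weight.code': torch.rand(16),
}

qs = QuantState.from_dict(quant_state_dict=quantized_stats, device=torch.device('cpu'))
assert qs.quant_type == 'nf4' and qs.blocksize == 64 and qs.dtype == torch.float16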
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
         data = quantized_stats.pop('weight')
         self = torch.Tensor._make_subclass(cls, data.to(device))
         self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_kwargs(kwargs=quantized_stats, device=device)
+        self.quant_state = QuantState.from_dict(quant_state_dict=quantized_stats, device=device)
         self.blocksize = self.quant_state.blocksize
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
...
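The Params4bit hunk leans on the torch.Tensor._make_subclass pattern: wrap already-quantized storage in a Parameter subclass without re-running Parameter.__new__, then attach the reconstructed quantization metadata as attributes. A stripped-down sketch of that pattern, with hypothetical names (MyParam4bit, from_stats) that only mirror the construction order above:

import torch

class MyParam4bit(torch.nn.Parameter):
    @classmethod
    def from_stats(cls, quantized_stats: dict, device: torch.device, requires_grad: bool = False):
        # same order as Params4bit above: pop the packed weight bytes,
        # wrap them in the subclass, then hang metadata off the instance
        data = quantized_stats.pop('weight')
        self = torch.Tensor._make_subclass(cls, data.to(device))
        self.requires_grad = requires_grad
        self.quant_state = quantized_stats  # the real code builds QuantState.from_dict(...) here
        return self

p = MyParam4bit.from_stats(
    {'weight': torch.zeros(8, dtype=torch.uint8), 'absmax': torch.rand(1)},
    device=torch.device('cpu'),
)
assert isinstance(p, torch.nn.Parameter) and 'absmax' in p.quant_state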