Commit 851806e0 authored by Ruslan Svirschevski
Browse files

renamed code to `quant_map` in serialized QState

parent 76b40a5c
...@@ -571,14 +571,14 @@ class QuantState: ...@@ -571,14 +571,14 @@ class QuantState:
"""container for quantization state components to work with Params4bit and similar clases""" """container for quantization state components to work with Params4bit and similar clases"""
valid_quant_types = ('fp4', 'nf4') valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types] valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state', valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] 'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
self.absmax = absmax self.absmax = absmax
self.shape = shape self.shape = shape
self.code = code # TODO consider renaming to `buckets / centroids / scale` self.code = code
self.dtype = dtype self.dtype = dtype
self.blocksize = blocksize self.blocksize = blocksize
self.quant_type = quant_type self.quant_type = quant_type
...@@ -627,7 +627,7 @@ class QuantState: ...@@ -627,7 +627,7 @@ class QuantState:
state2 = cls( state2 = cls(
absmax=qs_dict['nested_absmax'].to(device), absmax=qs_dict['nested_absmax'].to(device),
blocksize=qs_dict['nested_blocksize'], blocksize=qs_dict['nested_blocksize'],
code=qs_dict['nested_code'].to(device), code=qs_dict['nested_quant_map'].to(device),
dtype=getattr(torch, qs_dict['nested_dtype']), dtype=getattr(torch, qs_dict['nested_dtype']),
) )
else: else:
...@@ -637,7 +637,7 @@ class QuantState: ...@@ -637,7 +637,7 @@ class QuantState:
quant_type=qs_dict['quant_type'], quant_type=qs_dict['quant_type'],
absmax=qs_dict['absmax'].to(device), absmax=qs_dict['absmax'].to(device),
blocksize=qs_dict['blocksize'], blocksize=qs_dict['blocksize'],
code=qs_dict['code'].to(device), code=qs_dict['quant_map'].to(device),
dtype=getattr(torch, qs_dict['dtype']), dtype=getattr(torch, qs_dict['dtype']),
shape=torch.Size(qs_dict['shape']), shape=torch.Size(qs_dict['shape']),
offset=offset, offset=offset,
...@@ -654,7 +654,7 @@ class QuantState: ...@@ -654,7 +654,7 @@ class QuantState:
'quant_type': self.quant_type, 'quant_type': self.quant_type,
'absmax': self.absmax, 'absmax': self.absmax,
'blocksize': self.blocksize, 'blocksize': self.blocksize,
'code': self.code, 'quant_map': self.code,
'dtype': str(self.dtype).strip('torch.'), 'dtype': str(self.dtype).strip('torch.'),
'shape': tuple(self.shape) if self.nested else None, 'shape': tuple(self.shape) if self.nested else None,
} }
...@@ -662,7 +662,7 @@ class QuantState: ...@@ -662,7 +662,7 @@ class QuantState:
qs_dict.update({ qs_dict.update({
'nested_absmax': self.state2.absmax, 'nested_absmax': self.state2.absmax,
'nested_blocksize': self.state2.blocksize, 'nested_blocksize': self.state2.blocksize,
'nested_code': self.state2.code, 'nested_quant_map': self.state2.code,
'nested_dtype': str(self.state2.dtype).strip('torch.'), 'nested_dtype': str(self.state2.dtype).strip('torch.'),
'nested_offset': self.offset.item(), 'nested_offset': self.offset.item(),
}) })
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment