Commit 1d9f0f2a authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

add QuantState.__get_item__() for compatibility

parent 0a0b531f
...@@ -579,7 +579,20 @@ class QuantState: ...@@ -579,7 +579,20 @@ class QuantState:
self.offset = offset self.offset = offset
self.state2 = state2 self.state2 = state2
self.nested = state2 is not None self.nested = state2 is not None
def __get_item__(self, idx):
"""
ensures compatibility with older quant state scheme with nested lists.
assumes the following layout:
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
"""
if self.nested:
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type]
else:
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
return list_repr[idx]
@classmethod @classmethod
def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState': def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment