Commit c6d0a847 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

cleanup 0

parent 7d1c9cfe
......@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return out
class QuantState:
"""container for quantization state components to work with Params4bit and similar clases"""
valid_quant_types = ('fp4', 'nf4')
......@@ -574,7 +575,6 @@ class QuantState:
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
self.absmax = absmax
self.shape = shape
......@@ -615,7 +615,7 @@ class QuantState:
if not len(qs_key) and 'quant_type' not in qs_dict:
raise ValueError("Expected packed or unpacked quant_state items, found neither")
elif len(qs_key) != 1:
raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
# unpacking minor and non-tensor quant state items if necessary
if len(qs_key) == 1:
......@@ -682,6 +682,7 @@ class QuantState:
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
......
......@@ -139,6 +139,7 @@ class Embedding(torch.nn.Embedding):
return emb
class Params4bit(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
......@@ -170,9 +171,9 @@ class Params4bit(torch.nn.Parameter):
cls.requires_grad = requires_grad
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
cls.blocksize = cls.quant_state.blocksize
cls.compress_statistics = cls.quant_state.nested
cls.quant_type = cls.quant_state.quant_type
cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
return self, state_dict
......@@ -213,6 +214,7 @@ class Params4bit(torch.nn.Parameter):
return new_param
class Linear4bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment