Commit c6d0a847 authored by Ruslan Svirschevski
Browse files

cleanup 0

parent 7d1c9cfe
...@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n ...@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return out return out
class QuantState: class QuantState:
"""container for quantization state components to work with Params4bit and similar classes""" """container for quantization state components to work with Params4bit and similar classes"""
valid_quant_types = ('fp4', 'nf4') valid_quant_types = ('fp4', 'nf4')
...@@ -574,7 +575,6 @@ class QuantState: ...@@ -574,7 +575,6 @@ class QuantState:
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] 'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
self.absmax = absmax self.absmax = absmax
self.shape = shape self.shape = shape
...@@ -585,7 +585,7 @@ class QuantState: ...@@ -585,7 +585,7 @@ class QuantState:
self.offset = offset self.offset = offset
self.state2 = state2 self.state2 = state2
self.nested = state2 is not None self.nested = state2 is not None
def __get_item__(self, idx): def __get_item__(self, idx):
""" """
ensures compatibility with older quant state scheme with nested lists. ensures compatibility with older quant state scheme with nested lists.
...@@ -598,7 +598,7 @@ class QuantState: ...@@ -598,7 +598,7 @@ class QuantState:
else: else:
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
return list_repr[idx] return list_repr[idx]
@classmethod @classmethod
def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
""" """
...@@ -606,7 +606,7 @@ class QuantState: ...@@ -606,7 +606,7 @@ class QuantState:
where necessary, convert into strings, torch.dtype, ints, etc. where necessary, convert into strings, torch.dtype, ints, etc.
qs_dict: based on state_dict, with only relevant keys, stripped of prefixes. qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.
item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
""" """
...@@ -615,8 +615,8 @@ class QuantState: ...@@ -615,8 +615,8 @@ class QuantState:
if not len(qs_key) and 'quant_type' not in qs_dict: if not len(qs_key) and 'quant_type' not in qs_dict:
raise ValueError("Expected packed or unpacked quant_state items, found neither") raise ValueError("Expected packed or unpacked quant_state items, found neither")
elif len(qs_key) != 1: elif len(qs_key) != 1:
raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items") raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
# unpacking minor and non-tensor quant state items if necessary # unpacking minor and non-tensor quant state items if necessary
if len(qs_key) == 1: if len(qs_key) == 1:
qs_key = qs_key[0] qs_key = qs_key[0]
...@@ -673,7 +673,7 @@ class QuantState: ...@@ -673,7 +673,7 @@ class QuantState:
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
return qs_packed_dict return qs_packed_dict
def to(self, device): def to(self, device):
# make sure the quantization state is on the right device # make sure the quantization state is on the right device
self.absmax = self.absmax.to(device) self.absmax = self.absmax.to(device)
...@@ -682,6 +682,7 @@ class QuantState: ...@@ -682,6 +682,7 @@ class QuantState:
self.state2.absmax = self.state2.absmax.to(device) self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device) self.state2.code = self.state2.code.to(device)
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
""" """
Quantize tensor A in blocks of size 4096 values. Quantize tensor A in blocks of size 4096 values.
......
...@@ -139,6 +139,7 @@ class Embedding(torch.nn.Embedding): ...@@ -139,6 +139,7 @@ class Embedding(torch.nn.Embedding):
return emb return emb
class Params4bit(torch.nn.Parameter): class Params4bit(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
...@@ -152,11 +153,11 @@ class Params4bit(torch.nn.Parameter): ...@@ -152,11 +153,11 @@ class Params4bit(torch.nn.Parameter):
self.quant_state = quant_state self.quant_state = quant_state
self.data = data self.data = data
return self return self
@classmethod @classmethod
def from_state_dict(cls, state_dict, prefix="", requires_grad=False): def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
data = state_dict.pop(prefix.rstrip('.')) data = state_dict.pop(prefix.rstrip('.'))
# extracting components for QuantState from state_dict # extracting components for QuantState from state_dict
qs_dict = {} qs_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
...@@ -164,15 +165,15 @@ class Params4bit(torch.nn.Parameter): ...@@ -164,15 +165,15 @@ class Params4bit(torch.nn.Parameter):
qs_dict[k] = v qs_dict[k] = v
state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict} state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()} qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
if data.device.type != "cuda": if data.device.type != "cuda":
raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}") raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
cls.requires_grad = requires_grad cls.requires_grad = requires_grad
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device) cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
cls.blocksize = cls.quant_state.blocksize cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
cls.compress_statistics = cls.quant_state.nested cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
cls.quant_type = cls.quant_state.quant_type cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
self = torch.Tensor._make_subclass(cls, data=data.to(data.device)) self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
return self, state_dict return self, state_dict
...@@ -207,14 +208,15 @@ class Params4bit(torch.nn.Parameter): ...@@ -207,14 +208,15 @@ class Params4bit(torch.nn.Parameter):
self.quant_state.to(device) self.quant_state.to(device)
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state, requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type) quant_type=self.quant_type)
return new_param return new_param
class Linear4bit(nn.Linear): class Linear4bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
super().__init__(input_features, output_features, bias, device) super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment