Commit 070f45d2 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

cleanup commented out deletions

parent 781fcd5b
......@@ -163,30 +163,6 @@ class Params4bit(torch.nn.Parameter):
self.compress_statistics = self.quant_state.nested
self.quant_type = self.quant_state.quant_type
return self
# @classmethod
# def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
# data = state_dict.pop(prefix.rstrip('.'))
# # extracting components for QuantState from state_dict
# qs_dict = {}
# for k, v in state_dict.items():
# if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
# qs_dict[k] = v
# state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
# qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
# if data.device.type != "cuda":
# raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
# cls.requires_grad = requires_grad
# cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
# cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
# cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
# cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
# self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
# return self, state_dict
def cuda(self, device):
w = self.data.contiguous().half().cuda(device)
......@@ -227,7 +203,7 @@ class Params4bit(torch.nn.Parameter):
class Linear4bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
# self.persistent_buffers = [] # TODO consider as way to save quant state
......@@ -261,18 +237,6 @@ class Linear4bit(nn.Linear):
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
# def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
# missing_keys, unexpected_keys, error_msgs):
# # Note: super()._load_from_state_dict() is not called here intentionally.
# if self.bias is not None:
# bias_data = state_dict.pop(prefix + "bias", None)
# self.bias.data = bias_data.to(self.bias.data.device)
# self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
# state_dict, prefix=prefix + "weight" + ".", requires_grad=False
# )
# unexpected_keys.extend(state_dict.keys())
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
......@@ -295,10 +259,12 @@ class Linear4bit(nn.Linear):
return out
class LinearFP4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
class LinearNF4(Linear4bit):
''' Implements the NF4 data type.
......@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment