Unverified commit 726f1470, authored by Tim Dettmers, committed by GitHub

Merge pull request #864 from poedator/save4_fixes

Fixes to recent PR #753
Parents: f1ef74f8, 54860539
bitsandbytes/functional.py:

```diff
@@ -567,14 +567,14 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out


 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""
     valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
-    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
-                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
+                     'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
@@ -585,7 +585,7 @@ class QuantState:
         self.offset = offset
         self.state2 = state2
         self.nested = state2 is not None

     def __get_item__(self, idx):
         """
         ensures compatibility with older quant state scheme with nested lists.
@@ -598,7 +598,7 @@ class QuantState:
         else:
             list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
         return list_repr[idx]

     @classmethod
     def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
         """
@@ -606,21 +606,24 @@
         where necessary, convert into strings, torch.dtype, ints, etc.
         qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
         item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
+        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")

         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+
+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -654,7 +657,7 @@ class QuantState:
             'quant_type': self.quant_type,
             'absmax': self.absmax,
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
             'shape': tuple(self.shape) if self.nested else None,
         }
@@ -673,7 +676,7 @@
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
         qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
         return qs_packed_dict

     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
@@ -682,6 +685,7 @@
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)


 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
```
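The `as_dict(packed=True)` / `from_dict` pair tightened up here is what lets a `QuantState` survive an ordinary `state_dict` round trip. A minimal sketch of that round trip at the functional level, assuming a CUDA build of bitsandbytes and the `quantize_4bit` / `dequantize_4bit` signatures current at the time of this PR:

```python
import torch
import bitsandbytes.functional as F

# Quantize a half-precision weight to 4-bit NF4; `qs` is a QuantState.
w = torch.randn(64, 64, dtype=torch.float16, device="cuda")
w4, qs = F.quantize_4bit(w, quant_type="nf4", compress_statistics=True)

# Serialize: tensor components plus one packed tensor that carries the
# non-tensor fields under the key 'quant_state.bitsandbytes__nf4'.
qs_dict = qs.as_dict(packed=True)

# Rebuild the state on the target device; from_dict strips key prefixes itself.
qs2 = F.QuantState.from_dict(qs_dict=qs_dict, device=torch.device("cuda"))

# Dequantizing with the original and the restored state gives the same tensor.
assert torch.equal(F.dequantize_4bit(w4, quant_state=qs),
                   F.dequantize_4bit(w4, quant_state=qs2))
```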
bitsandbytes/nn/modules.py:

```diff
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional, TypeVar, Union, overload
+from typing import Any, Dict, Optional, TypeVar, Union, overload

 import warnings

 import torch
@@ -139,9 +139,10 @@ class Embedding(torch.nn.Embedding):
         return emb


 class Params4bit(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
         if data is None:
             data = torch.empty(0)
@@ -152,27 +153,16 @@ class Params4bit(torch.nn.Parameter):
         self.quant_state = quant_state
         self.data = data
         return self

     @classmethod
-    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-        data = state_dict.pop(prefix.rstrip('.'))
-
-        # extracting components for QuantState from state_dict
-        qs_dict = {}
-        for k, v in state_dict.items():
-            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-                qs_dict[k] = v
-        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-        if data.device.type != "cuda":
-            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-        cls.requires_grad = requires_grad,
-        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-        return self, state_dict
+    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -204,15 +194,16 @@ class Params4bit(torch.nn.Parameter):
             self.quant_state.to(device)

         new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                requires_grad=self.requires_grad, quant_state=self.quant_state,
                                blocksize=self.blocksize, compress_statistics=self.compress_statistics,
                                quant_type=self.quant_type)

         return new_param


 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ class Linear4bit(nn.Linear):
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        # Note: super()._load_from_state_dict() is not called here intentionally.
-        if self.bias is not None:
-            bias_data = state_dict.pop(prefix + "bias", None)
-            self.bias.data = bias_data.to(self.bias.data.device)
-        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-        )
-        unexpected_keys.extend(state_dict.keys())

     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -280,10 +259,12 @@ class Linear4bit(nn.Linear):
         return out


 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)


 class LinearNF4(Linear4bit):
     ''' Implements the NF4 data type.
@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
```
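With `_load_from_state_dict` removed above, a quantized layer is restored by rebuilding its weight explicitly through `Params4bit.from_prequantized`, which is what the updated test below does. A rough sketch of that flow, assuming a CUDA device (the weight is quantized when the layer is moved to the GPU):

```python
import torch
import bitsandbytes as bnb

# Build a 4-bit linear layer; Params4bit quantizes the weight on the move to CUDA.
layer = bnb.nn.Linear4bit(64, 64, bias=True, compute_dtype=torch.float16,
                          compress_statistics=True, quant_type="nf4")
layer = layer.to("cuda")

# The state_dict holds the packed 4-bit 'weight', the 'bias', and flattened
# quant-state entries such as 'weight.absmax', 'weight.quant_map' and
# 'weight.quant_state.bitsandbytes__nf4' (plus nested_* keys because
# compress_statistics=True).
sd = layer.state_dict()

# Rebuild the parameter from those entries and attach it to a fresh layer.
bias_data = sd.pop("bias", None)
weight_data = sd.pop("weight")
weight = bnb.nn.Params4bit.from_prequantized(data=weight_data, quantized_stats=sd, device="cuda")

layer2 = bnb.nn.Linear4bit(64, 64, bias=True, compute_dtype=torch.float16,
                           compress_statistics=True, quant_type="nf4")
layer2.weight = weight
if bias_data is not None:
    layer2.bias = torch.nn.Parameter(bias_data)

# Both layers should now produce identical outputs.
x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
assert torch.equal(layer(x), layer2(x))
```

Keeping the restore path on a `Params4bit` classmethod rather than overriding `_load_from_state_dict` leaves torch's normal `load_state_dict` machinery untouched.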
tests/test_linear4bit.py:

```diff
@@ -7,8 +7,6 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # saving to state_dict:
     sd = linear_q.state_dict()

+    # restoring from state_dict:
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         device=device,  # TODO create on meta device to save loading time
     )

     # loading weights from state_dict:
-    linear_q2.load_state_dict(sd)
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias = torch.nn.Parameter(bias_data2)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
@@ -61,7 +64,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     assert a.device == b.device
     assert a.dtype == b.dtype
     assert torch.equal(a, b)

     q0 = a.quant_state
     q1 = b.quant_state
     for attr in ('code', 'dtype', 'blocksize', 'absmax'):
```
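As background, the non-tensor quant-state fields only survive a `state_dict` because `as_dict(packed=True)` and `from_dict` route them through the `pack_dict_to_tensor` / `unpack_tensor_to_dict` helpers in `functional.py`, which pack the metadata into a plain tensor that can sit next to the real tensors. A minimal sketch of that idea, with hypothetical helper names rather than the library's exact implementation:

```python
import json
import torch

def pack_metadata_to_tensor(d: dict) -> torch.Tensor:
    """Encode a JSON-serializable dict as a uint8 tensor so it can live in a state_dict."""
    raw = json.dumps(d).encode("utf-8")
    return torch.tensor(list(raw), dtype=torch.uint8)

def unpack_tensor_to_metadata(t: torch.Tensor) -> dict:
    """Inverse of pack_metadata_to_tensor."""
    raw = bytes(t.tolist())
    return json.loads(raw.decode("utf-8"))

meta = {"quant_type": "nf4", "blocksize": 64, "dtype": "float16", "shape": [64, 64]}
packed = pack_metadata_to_tensor(meta)   # a plain tensor: any state_dict consumer can store it
assert unpack_tensor_to_metadata(packed) == meta
```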