Unverified commit 726f1470, authored by Tim Dettmers, committed by GitHub

Merge pull request #864 from poedator/save4_fixes

fixes to recent PR #753
parents f1ef74f8 54860539
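
For orientation, here is a minimal sketch of the save/restore round-trip these fixes settle on, written against the API exactly as it appears in the diff below; the layer sizes, fp16 dtype, and NF4 quant_type are illustrative and not taken from the PR.

    import torch
    import bitsandbytes as bnb

    # Quantize a small fp16 layer to NF4 (sizes and quant_type are illustrative).
    linear = torch.nn.Linear(64, 32, bias=True, dtype=torch.float16)
    linear_q = bnb.nn.Linear4bit(64, 32, bias=True, quant_type="nf4", device="cpu")
    # Params4bit.cuda() packs the weight to 4 bit and attaches its QuantState.
    linear_q.weight = bnb.nn.Params4bit(linear.weight.data, requires_grad=False, quant_type="nf4").cuda("cuda")
    linear_q.bias = torch.nn.Parameter(linear.bias.data.cuda())

    # Save: the packed quantization state is emitted under "weight.*" keys.
    sd = linear_q.state_dict()

    # Restore: pop the plain tensors, then rebuild the 4-bit weight from the remaining stats.
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    linear_q2 = bnb.nn.Linear4bit(64, 32, bias=True, quant_type="nf4", device="cuda")
    linear_q2.weight = weight2.to("cuda")
    if bias_data2 is not None:
        linear_q2.bias = torch.nn.Parameter(bias_data2)

    # Rough equivalence checks, in the spirit of the test's MATCHING section.
    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    assert torch.equal(linear_q.weight.data, linear_q2.weight.data)
    assert torch.allclose(linear_q(x), linear_q2(x))

The notable change is that restoring no longer goes through load_state_dict(): the caller pops the raw tensors and hands the remaining quant-state entries to Params4bit.from_prequantized, as the updated test at the bottom of this diff does.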
@@ -567,13 +567,13 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out


 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""

     valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
-    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
-                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
+                     'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
@@ -611,16 +611,19 @@ class QuantState:
         """
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
+        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")

         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+
+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -682,6 +685,7 @@ class QuantState:
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)


 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
...
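
To make the new prefix handling in QuantState.from_dict concrete, here is a small self-contained sketch of the key layout it now tolerates. The key names and placeholder tensors are illustrative; in the library the packed item is expanded with unpack_tensor_to_dict rather than discarded.

    import torch

    # Illustrative slice of a Linear4bit state dict (tensor values are placeholders;
    # only the key names matter here).
    qs_dict = {
        "weight.absmax": torch.zeros(8),
        "weight.quant_map": torch.zeros(16),
        "weight.quant_state.bitsandbytes__nf4": torch.zeros(32, dtype=torch.uint8),  # packed non-tensor components
    }

    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in ("fp4", "nf4")]

    # from_dict() locates the single packed item by the "quant_state" substring and
    # validates its suffix against valid_qs_type_keys ...
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    assert len(qs_key) == 1 and qs_key[0].split(".")[-1] in valid_qs_type_keys

    # ... expands it (unpack_tensor_to_dict in the library; skipped in this sketch) and
    # then strips the "weight." / "quant_state." prefixes, leaving bare component names.
    packed = qs_dict.pop(qs_key[0])
    stripped = {k.split(".")[-1]: v for k, v in qs_dict.items()}
    print(sorted(stripped))  # ['absmax', 'quant_map']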
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional, TypeVar, Union, overload
+from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings

 import torch
@@ -139,9 +139,10 @@ class Embedding(torch.nn.Embedding):
         return emb


 class Params4bit(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
         if data is None:
             data = torch.empty(0)
@@ -154,25 +155,14 @@ class Params4bit(torch.nn.Parameter):
         return self

     @classmethod
-    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-        data = state_dict.pop(prefix.rstrip('.'))
-
-        # extracting components for QuantState from state_dict
-        qs_dict = {}
-        for k, v in state_dict.items():
-            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-                qs_dict[k] = v
-        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-        if data.device.type != "cuda":
-            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-
-        cls.requires_grad = requires_grad,
-        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-
-        return self, state_dict
+    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -210,9 +200,10 @@ class Params4bit(torch.nn.Parameter):
         return new_param


 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ class Linear4bit(nn.Linear):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        # Note: super()._load_from_state_dict() is not called here intentionally.
-        if self.bias is not None:
-            bias_data = state_dict.pop(prefix + "bias", None)
-            self.bias.data = bias_data.to(self.bias.data.device)
-
-        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-        )
-        unexpected_keys.extend(state_dict.keys())
-
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -280,10 +259,12 @@ class Linear4bit(nn.Linear):
         return out


 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)


 class LinearNF4(Linear4bit):
     ''' Implements the NF4 data type.

@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
...
@@ -7,8 +7,6 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # saving to state_dict:
     sd = linear_q.state_dict()

+    # restoring from state_dict:
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         device=device,  # TODO create on meta device to save loading time
     )

     # loading weights from state_dict:
-    linear_q2.load_state_dict(sd)
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias = torch.nn.Parameter(bias_data2)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
...