Unverified commit 726f1470 authored by Tim Dettmers, committed by GitHub

Merge pull request #864 from poedator/save4_fixes

fixes to recent PR #753
parents f1ef74f8 54860539
@@ -567,13 +567,13 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return out
class QuantState:
"""container for quantization state components to work with Params4bit and similar clases"""
valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
self.absmax = absmax
@@ -611,16 +611,19 @@ class QuantState:
"""
# unpacking tensor with non-tensor components
qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
if not len(qs_key) and 'quant_type' not in qs_dict:
raise ValueError("Expected packed or unpacked quant_state items, found neither")
elif len(qs_key) != 1:
raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")
# unpacking minor and non-tensor quant state items if necessary
if len(qs_key) == 1:
qs_key = qs_key[0]
qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes
assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
if 'nested_absmax' in qs_dict:
offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -682,6 +685,7 @@ class QuantState:
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
......
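For context on the QuantState changes in functional.py above: the packed metadata tensor is now stored under a key that ends with an entry from valid_qs_type_keys (e.g. bitsandbytes__nf4), and from_dict strips any key prefix before validating against valid_qs_keys. Below is a rough sketch of the two qs_dict layouts it accepts, with placeholder tensor values; an NF4 state is assumed and the nested-statistics entries are omitted, so this only illustrates the key scheme and is not a loadable state.

import torch

# Illustration only: placeholder values showing the key layouts that
# QuantState.from_dict() accepts once the "weight." prefix is stripped.

# 1) packed form (what Linear4bit now saves): all non-tensor metadata is
#    serialized into a single tensor keyed by the quant type.
packed_qs_dict = {
    "absmax": torch.zeros(128, dtype=torch.uint8),
    "quant_map": torch.zeros(16),
    "quant_state.bitsandbytes__nf4": torch.zeros(200, dtype=torch.uint8),  # packed metadata
}

# 2) unpacked form: plain entries drawn from valid_qs_keys.
unpacked_qs_dict = {
    "absmax": torch.zeros(128, dtype=torch.uint8),
    "quant_map": torch.zeros(16),
    "quant_type": "nf4",
    "blocksize": 64,
    "dtype": "float16",
    "shape": (64, 32),
}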
@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings
import torch
@@ -139,9 +139,10 @@ class Embedding(torch.nn.Embedding):
return emb
class Params4bit(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
if data is None:
data = torch.empty(0)
@@ -154,25 +155,14 @@ class Params4bit(torch.nn.Parameter):
return self
@classmethod
def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
data = state_dict.pop(prefix.rstrip('.'))
# extracting components for QuantState from state_dict
qs_dict = {}
for k, v in state_dict.items():
if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
qs_dict[k] = v
state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
if data.device.type != "cuda":
raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
cls.requires_grad = requires_grad,
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
return self, state_dict
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
self.blocksize = self.quant_state.blocksize
self.compress_statistics = self.quant_state.nested
self.quant_type = self.quant_state.quant_type
return self
def cuda(self, device):
w = self.data.contiguous().half().cuda(device)
@@ -210,9 +200,10 @@ class Params4bit(torch.nn.Parameter):
return new_param
class Linear4bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
# self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ class Linear4bit(nn.Linear):
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
# Note: super()._load_from_state_dict() is not called here intentionally.
if self.bias is not None:
bias_data = state_dict.pop(prefix + "bias", None)
self.bias.data = bias_data.to(self.bias.data.device)
self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
state_dict, prefix=prefix + "weight" + ".", requires_grad=False
)
unexpected_keys.extend(state_dict.keys())
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -280,10 +259,12 @@ class Linear4bit(nn.Linear):
return out
class LinearFP4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
class LinearNF4(Linear4bit):
''' Implements the NF4 data type.
@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
......
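With _load_from_state_dict removed above, serialization in bitsandbytes.nn.modules now works one way per direction: Linear4bit._save_to_state_dict writes each entry of quant_state.as_dict(packed=True) under the "weight." prefix, and restoring goes through the new Params4bit.from_prequantized classmethod. A minimal sketch of the save side follows; it assumes an NF4 layer on a CUDA device, and the layer sizes and the listed key names are illustrative.

import torch
import bitsandbytes as bnb

# Sketch (requires a CUDA device): moving the layer to the GPU quantizes its
# Params4bit weight; state_dict() then also carries the quantization state
# under "weight."-prefixed keys.
layer = bnb.nn.Linear4bit(32, 64, bias=True, quant_type="nf4").cuda()
for name, tensor in layer.state_dict().items():
    print(name, tuple(tensor.shape))
# expected keys (exact set depends on compress_statistics), e.g.:
#   weight, bias, weight.absmax, weight.quant_map,
#   weight.nested_absmax, weight.nested_quant_map,
#   weight.quant_state.bitsandbytes__nf4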
@@ -7,8 +7,6 @@ import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.nn.modules import Linear4bit
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
# saving to state_dict:
sd = linear_q.state_dict()
# restoring from state_dict:
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
# creating new layer with same params:
linear_q2 = bnb.nn.Linear4bit(
linear.in_features,
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
device=device, # TODO create on meta device to save loading time
)
# loading weights from state_dict:
linear_q2.load_state_dict(sd)
linear_q2.weight = weight2.to(device)
if bias:
linear_q2.bias = torch.nn.Parameter(bias_data2)
# MATCHING
a, b = linear_q.weight, linear_q2.weight
......
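Putting the pieces together, the updated test restores a layer by rebuilding the 4-bit weight from the serialized components instead of calling load_state_dict on the quantized module. A condensed sketch of that round trip follows; the sizes, quant_type, compute_dtype, and the .cuda() call that triggers the initial quantization are assumptions for illustration rather than taken from the test.

import torch
import bitsandbytes as bnb

# quantize: the weight is packed to 4 bits when the layer is moved to the GPU
linear_q = bnb.nn.Linear4bit(32, 64, bias=True, compute_dtype=torch.float16,
                             quant_type="nf4").cuda()
sd = linear_q.state_dict()

# restore: pop the plain tensors, then rebuild the Params4bit from the rest
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

linear_q2 = bnb.nn.Linear4bit(32, 64, bias=True, compute_dtype=torch.float16,
                              quant_type="nf4", device="cuda")
linear_q2.weight = weight2.to("cuda")
if bias_data2 is not None:
    linear_q2.bias = torch.nn.Parameter(bias_data2)

# both layers should produce the same output for the same input
x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
assert torch.allclose(linear_q(x), linear_q2(x))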