Unverified commit ad6eef99 authored by Tim Dettmers, committed by GitHub

Merge pull request #753 from poedator/save4

Save and load in NF4 / FP4 formats
parents e812136c 851806e0
@@ -133,3 +133,4 @@ dmypy.json
 dependencies
 cuda_build
+.vscode/*
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -504,7 +504,7 @@ class MatMul4Bit(torch.autograd.Function):
             ctx.A = A
             ctx.B = B
             ctx.bias = bias
-            B_shape = state[1]
+            B_shape = quant_state.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -513,10 +513,10 @@ class MatMul4Bit(torch.autograd.Function):

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
-        ctx.state = state
+        ctx.state = quant_state
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

         if any(ctx.needs_input_grad[:2]):
@@ -534,7 +534,6 @@ class MatMul4Bit(torch.autograd.Function):
         req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
         A, B = ctx.tensors
-        state = ctx.state

         grad_A, grad_B, grad_bias = None, None, None
@@ -563,12 +562,11 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if A.shape[-1] % blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+        if A.shape[-1] % quant_state.blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
             out = F.gemv_4bit(A, B.t(), out, state=quant_state)
......
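For reference, a minimal sketch (not part of this diff) of how the reworked `matmul_4bit` path can be exercised with a `QuantState` produced by `quantize_4bit`. The shapes, dtypes and the CUDA device below are illustrative assumptions; the function and attribute names follow this diff.

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a fp16 weight matrix to NF4; quantize_4bit now returns a QuantState object.
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W4, qs = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

print(qs.shape, qs.blocksize, qs.quant_type, qs.nested)  # attribute access replaces list indexing

# A batched input takes the MatMul4Bit (dequantize + linear) path shown above.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, W4.t(), quant_state=qs)
assert y.shape == (8, 4096)
```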
@@ -13,8 +13,9 @@ from scipy.stats import norm
 import numpy as np

 from functools import reduce  # Required in Python 3
-from typing import Tuple
+from typing import Tuple, Any, Dict
 from torch import Tensor
+from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

 from .cextension import COMPILED_WITH_CUDA, lib
@@ -566,6 +567,120 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
+
+
+class QuantState:
+    """Container for quantization state components to work with Params4bit and similar classes."""
+    valid_quant_types = ('fp4', 'nf4')
+    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
+                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+
+    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
+        self.absmax = absmax
+        self.shape = shape
+        self.code = code
+        self.dtype = dtype
+        self.blocksize = blocksize
+        self.quant_type = quant_type
+        self.offset = offset
+        self.state2 = state2
+        self.nested = state2 is not None
+
+    def __getitem__(self, idx):
+        """
+        Ensures compatibility with the older quant state scheme based on nested lists.
+        Assumes the following layout:
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        """
+        if self.nested:
+            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type]
+        else:
+            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
+        return list_repr[idx]
+
+    @classmethod
+    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
+        """
+        Unpacks components of a state_dict into a QuantState, converting strings into
+        torch.dtype, ints, etc. where necessary.
+
+        qs_dict: based on state_dict, with only the relevant keys, stripped of prefixes.
+        The item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain the minor and non-tensor quant state items.
+        """
+        # unpacking tensor with non-tensor components
+        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        if not len(qs_key) and 'quant_type' not in qs_dict:
+            raise ValueError("Expected packed or unpacked quant_state items, found neither")
+        elif len(qs_key) > 1:
+            raise ValueError(f"There should be at most one quant_state item with a key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+
+        # unpacking minor and non-tensor quant state items if necessary
+        if len(qs_key) == 1:
+            qs_key = qs_key[0]
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+
+        if 'nested_absmax' in qs_dict:
+            offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
+            state2 = cls(
+                absmax=qs_dict['nested_absmax'].to(device),
+                blocksize=qs_dict['nested_blocksize'],
+                code=qs_dict['nested_quant_map'].to(device),
+                dtype=getattr(torch, qs_dict['nested_dtype']),
+            )
+        else:
+            offset, state2 = None, None
+
+        quant_state = cls(
+            quant_type=qs_dict['quant_type'],
+            absmax=qs_dict['absmax'].to(device),
+            blocksize=qs_dict['blocksize'],
+            code=qs_dict['quant_map'].to(device),
+            dtype=getattr(torch, qs_dict['dtype']),
+            shape=torch.Size(qs_dict['shape']),
+            offset=offset,
+            state2=state2,
+        )
+        return quant_state
+
+    def as_dict(self, packed=False):
+        """
+        Returns a dict of tensors and strings to use in serialization via _save_to_state_dict().
+        param: packed -- returns dict[str, torch.Tensor] for state_dict
+        """
+        qs_dict = {
+            'quant_type': self.quant_type,
+            'absmax': self.absmax,
+            'blocksize': self.blocksize,
+            'quant_map': self.code,
+            'dtype': str(self.dtype).strip('torch.'),
+            'shape': tuple(self.shape) if self.shape is not None else None,
+        }
+        if self.nested:
+            qs_dict.update({
+                'nested_absmax': self.state2.absmax,
+                'nested_blocksize': self.state2.blocksize,
+                'nested_quant_map': self.state2.code,
+                'nested_dtype': str(self.state2.dtype).strip('torch.'),
+                'nested_offset': self.offset.item(),
+            })
+        if not packed:
+            return qs_dict
+
+        qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
+        non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
+        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
+        return qs_packed_dict
+
+    def to(self, device):
+        # make sure the quantization state is on the right device
+        self.absmax = self.absmax.to(device)
+        if self.nested:
+            self.offset = self.offset.to(device)
+            self.state2.absmax = self.state2.absmax.to(device)
+            self.state2.code = self.state2.code.to(device)
+
 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
@@ -633,16 +748,16 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
+        quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
     else:
-        state = [absmax, code, blocksize, nested, A.dtype, None, None]
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)

-    return out, state
+    return out, quant_state


 def dequantize_blockwise(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    quant_state: QuantState = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
@@ -659,8 +774,8 @@ def dequantize_blockwise(
     ----------
     A : torch.Tensor
         The input 8-bit tensor.
-    quant_state : tuple(torch.Tensor, torch.Tensor)
-        Tuple of code and absmax values.
+    quant_state : QuantState
+        Object with code, absmax and other quantization state components.
     absmax : torch.Tensor
         The absmax values.
     code : torch.Tensor
@@ -681,36 +796,35 @@ def dequantize_blockwise(
         code = name2qmap["dynamic"]

     if quant_state is None:
-        quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)

-    absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-
-    if nested:
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(A.shape, dtype=dtype, device=A.device)
+        out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)

     if A.device.type != 'cpu':
         device = pre_call(A.device)
-        code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+        code = quant_state.code.to(A.device)
+        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.bfloat16:
-            lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
-        code = code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+        code = quant_state.code.cpu()
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))

     return out
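A small sketch of the updated 8-bit blockwise API, which now returns a `QuantState` instead of a list. The tensor size and device are illustrative assumptions.

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# quantize_blockwise now returns (quantized tensor, QuantState) instead of (tensor, list).
qA, qs = F.quantize_blockwise(A, blocksize=4096)
print(qs.blocksize, qs.dtype, qs.nested)   # attributes replace positional unpacking

# dequantize_blockwise reads absmax/code/blocksize/dtype from the QuantState.
A_dq = F.dequantize_blockwise(qA, qs)
assert A_dq.shape == A.shape and A_dq.dtype == torch.float16
print((A - A_dq).abs().mean())             # small but nonzero: 8-bit round trip is lossy
```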
@@ -765,7 +879,6 @@ def get_4bit_type(typename, device=None, blocksize=64):

     return data.to(device)


 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
@@ -839,26 +952,26 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

-    datatype = get_4bit_type(quant_type, device=A.device)
+    code = get_4bit_type(quant_type, device=A.device)

     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
+        state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
+        state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type)

     return out, state


-def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')


-def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')


-def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
@@ -868,8 +981,8 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     ----------
     A : torch.Tensor
         The input 8-bit tensor (packed 4-bit values).
-    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
-        Tuple of absmax values, original tensor shape and original dtype.
+    quant_state : QuantState
+        Object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
@@ -892,41 +1005,40 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     if quant_state is None:
         assert absmax is not None and out is not None
-        shape = out.shape
-        dtype = out.dtype
+
+        quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
+
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if compressed_stats is not None:
-            offset, state2 = compressed_stats
-            absmax = dequantize_blockwise(absmax, state2)
-            absmax += offset
+        absmax = quant_state.absmax
+
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset

     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(shape, dtype=dtype, device=A.device)
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

     n = out.numel()

     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.float16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.bfloat16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
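To illustrate the nested (compressed statistics) path that `dequantize_4bit` now resolves through `quant_state.nested`, a short sketch; sizes, dtype and device are illustrative assumptions.

```python
import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# With compress_statistics=True the absmax values are themselves 8-bit quantized,
# so the returned QuantState carries a second-level state in .state2.
W4, qs = F.quantize_4bit(W, blocksize=64, quant_type="fp4", compress_statistics=True)
assert qs.nested and qs.state2 is not None and qs.offset is not None

# dequantize_4bit unpacks the nested absmax internally before decoding the 4-bit data.
W_dq = F.dequantize_4bit(W4, qs)
assert W_dq.shape == qs.shape and W_dq.dtype == qs.dtype
```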
@@ -952,22 +1064,22 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:

 def dequantize(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    state: Tuple[Tensor, Tensor] = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
 ) -> Tensor:
-    assert quant_state is not None or absmax is not None
-    if code is None and quant_state is None:
+    assert state is not None or absmax is not None
+    if code is None and state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
         code = code.to(A.device)

-    if quant_state is None:
-        quant_state = (absmax, code)
-    out = dequantize_no_absmax(A, quant_state[1], out)
-    return out * quant_state[0]
+    if state is None:
+        state = (absmax, code)
+    out = dequantize_no_absmax(A, state[1], out)
+    return out * state[0]


 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
@@ -1482,13 +1594,12 @@ def gemv_4bit(
     if A.numel() != A.shape[-1]:
        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = state[1]
+    Bshape = state.shape
     bout = Bshape[0]
-    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = state.absmax
+    if state.nested:
+        absmax = dequantize_blockwise(state.absmax, state.state2)
+        absmax += state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1502,7 +1613,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, state[-1]])
+    is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1512,11 +1623,11 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
......
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn

 import bitsandbytes as bnb
-import bitsandbytes.functional
+from bitsandbytes.functional import QuantState
 from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@@ -140,6 +140,7 @@ class Embedding(torch.nn.Embedding):
         return emb


 class Params4bit(torch.nn.Parameter):
     def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
         if data is None:
             data = torch.empty(0)
@@ -152,6 +153,27 @@ class Params4bit(torch.nn.Parameter):
         self.data = data
         return self
+    @classmethod
+    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+        data = state_dict.pop(prefix.rstrip('.'))
+
+        # extracting components for QuantState from state_dict
+        qs_dict = {}
+        for k, v in state_dict.items():
+            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+                qs_dict[k] = v
+        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+
+        if data.device.type != "cuda":
+            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+
+        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+        return self, state_dict
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
@@ -178,22 +200,9 @@ class Params4bit(torch.nn.Parameter):
         if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
             return self.cuda(device)
         else:
-            s = self.quant_state
-            if s is not None:
-                # make sure the quantization state is on the right device
-                s[0] = s[0].to(device)
-                if self.compress_statistics:
-                    # TODO: refactor this. This is a nightmare
-                    # for 4-bit:
-                    # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
-                    # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
-                    #s[-2][0] = s[-2][0].to(device) # offset
-                    #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
-                    # for 8-bit
-                    s[-3][0] = s[-3][0].to(device) # offset
-                    s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics
-                    s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook
+            if self.quant_state is not None:
+                self.quant_state.to(device)

             new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                    requires_grad=self.requires_grad, quant_state=self.quant_state,
                                    blocksize=self.blocksize, compress_statistics=self.compress_statistics,
@@ -202,9 +211,11 @@ class Params4bit(torch.nn.Parameter):
         return new_param


 class Linear4bit(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        # self.persistent_buffers = []  # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
@@ -224,10 +235,28 @@ class Linear4bit(nn.Linear):
                 warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
             warnings.filterwarnings('ignore', message='.*inference or training')
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        """
+        save weight and bias,
+        then fill state_dict with components of quant_state
+        """
+        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
+
+        if getattr(self.weight, "quant_state", None) is not None:
+            for k, v in self.weight.quant_state.as_dict(packed=True).items():
+                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Note: super()._load_from_state_dict() is not called here intentionally.
+        if self.bias is not None:
+            bias_data = state_dict.pop(prefix + "bias", None)
+            self.bias.data = bias_data.to(self.bias.data.device)
+
+        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+        )
+        unexpected_keys.extend(state_dict.keys())
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -270,7 +299,6 @@ class LinearNF4(Linear4bit):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)


 class Int8Params(torch.nn.Parameter):
     def __new__(
         cls,
......
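Taken together, the module-level hooks make 4-bit layers behave like ordinary modules with respect to `state_dict`. A usage sketch mirroring the new test; layer sizes, the CUDA device and the file path are illustrative assumptions.

```python
import torch
import bitsandbytes as bnb

# Build and quantize a small layer (moving the params to CUDA triggers quantization).
fp16_layer = torch.nn.Linear(300, 400, dtype=torch.float16)
layer = bnb.nn.Linear4bit(300, 400, quant_type="nf4", compress_statistics=True, device="cuda")
layer.weight = bnb.nn.Params4bit(data=fp16_layer.weight, requires_grad=False).to("cuda")
layer.bias.data = fp16_layer.bias.data.to("cuda")

# _save_to_state_dict() adds the packed quant_state entries next to "weight" and "bias".
sd = layer.state_dict()
torch.save(sd, "linear4bit.pth")            # illustrative path

# _load_from_state_dict() rebuilds Params4bit (and its QuantState) from those entries.
layer2 = bnb.nn.Linear4bit(300, 400, quant_type="nf4", compress_statistics=True, device="cuda")
layer2.load_state_dict(torch.load("linear4bit.pth"))

x = torch.rand(8, 300, device="cuda")
assert torch.equal(layer(x), layer2(x))
```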
+import json
 import shlex
 import subprocess
 import torch
@@ -158,3 +159,36 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei
         if func is not None: func(module)
     return model
+
+
+def pack_dict_to_tensor(source_dict):
+    """
+    Pack a dictionary into a torch tensor for storing quant_state items in state_dict.
+
+    Parameters:
+    - source_dict: The dictionary to be packed.
+
+    Returns:
+    A torch tensor containing the packed data.
+    """
+    json_str = json.dumps(source_dict)
+    json_bytes = json_str.encode('utf-8')
+    tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8)
+
+    return tensor_data
+
+
+def unpack_tensor_to_dict(tensor_data):
+    """
+    Unpack a torch tensor into a Python dictionary.
+
+    Parameters:
+    - tensor_data: The torch tensor containing the packed data.
+
+    Returns:
+    A Python dictionary containing the unpacked data.
+    """
+    json_bytes = bytes(tensor_data.numpy())
+    json_str = json_bytes.decode('utf-8')
+    unpacked_dict = json.loads(json_str)
+
+    return unpacked_dict
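The two helpers are deliberately generic; a quick round-trip sketch (the dict contents are illustrative):

```python
import torch
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

# Non-tensor quant_state components are JSON-encoded into a uint8 tensor so they
# can live inside a regular state_dict / checkpoint file.
meta = {"quant_type": "nf4", "blocksize": 64, "dtype": "float16", "shape": [300, 400]}
packed = pack_dict_to_tensor(meta)
assert packed.dtype == torch.uint8

restored = unpack_tensor_to_dict(packed)
assert restored == meta
```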
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.nn.modules import Linear4bit


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize(
    "quant_type, compress_statistics, bias",
    list(product(["nf4", "fp4"], [False, True], [False, True])),
)
def test_linear_serialization(quant_type, compress_statistics, bias):
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device=device,
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
    linear_q.weight = new_weight.to(device)
    if bias:
        linear_q.bias.data = linear.bias.data.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device=device,  # TODO create on meta device to save loading time
    )
    # loading weights from state_dict:
    linear_q2.load_state_dict(sd)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(state_path_4bit)
        size_ratio = size_4 / size_orig
        target_compression = 0.143 if original_dtype == torch.float32 else 0.285
        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        assert size_ratio < target_compression, ratio_error_msg