Commit 61a4a20d authored by Ruslan Svirschevski's avatar Ruslan Svirschevski

use QuantState class for quant_state
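
In short: the positional quant_state lists returned by quantize_blockwise / quantize_4bit are replaced by a QuantState object with named attributes (absmax, shape, code, dtype, blocksize, quant_type, offset, state2), and call sites are switched from index-based unpacking to attribute access. A minimal usage sketch of the new interface (the shapes, device and the nf4/compress_statistics settings below are illustrative, not part of this commit):

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
    qW, quant_state = F.quantize_4bit(W, quant_type='nf4', compress_statistics=True)

    # before this commit: positional unpacking of a plain list
    #   absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
    # after this commit: named attributes on a QuantState instance
    print(quant_state.shape, quant_state.blocksize, quant_state.quant_type, quant_state.nested)

    W_dq = F.dequantize_4bit(qW, quant_state)  # the state object is passed through unchanged
    quant_state.to('cuda')                     # moves absmax (and nested statistics) in place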

parent e812136c
......@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=None):
def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
......@@ -504,7 +504,7 @@ class MatMul4Bit(torch.autograd.Function):
ctx.A = A
ctx.B = B
ctx.bias = bias
B_shape = state[1]
B_shape = quant_state.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
......@@ -513,10 +513,10 @@ class MatMul4Bit(torch.autograd.Function):
# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
# 3. Save state
ctx.state = state
ctx.state = quant_state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
......@@ -534,7 +534,6 @@ class MatMul4Bit(torch.autograd.Function):
req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
A, B = ctx.tensors
state = ctx.state
grad_A, grad_B, grad_bias = None, None, None
......@@ -563,15 +562,14 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
assert quant_state is not None
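    # Inference fast path: a single-token input (A.numel() == A.shape[-1]) that does not require
    # grad uses the fused 4-bit GEMV kernel, provided the hidden dimension is a multiple of
    # quant_state.blocksize; otherwise fall back to MatMul4Bit (dequantize + matmul).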
if A.numel() == A.shape[-1] and A.requires_grad == False:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
if A.shape[-1] % blocksize != 0:
warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
if A.shape[-1] % quant_state.blocksize != 0:
warn(f'The hidden dimension of some matrices is not a multiple of {quant_state.blocksize}; efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
if bias is not None:
out += bias
return out
......
......@@ -566,6 +566,25 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return out
class QuantState:
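    """Container for the quantization state of a blockwise-quantized tensor.

    Replaces the positional lists that were previously passed around as quant_state.
    When compress_statistics is used, state2 holds the nested quantization state of
    the (quantized) absmax and offset the mean subtracted before that quantization;
    nested is set accordingly.
    """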
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
self.absmax = absmax
self.shape = shape
self.code = code
self.dtype = dtype
self.blocksize = blocksize
self.quant_type = quant_type
self.offset = offset
self.state2 = state2
self.nested = state2 is not None
def to(self, device):
# make sure the quantization state is on the right device
self.absmax = self.absmax.to(device)
if self.nested:
self.offset = self.offset.to(device)
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
"""
......@@ -633,16 +652,16 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
offset = absmax.mean()
absmax -= offset
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
else:
state = [absmax, code, blocksize, nested, A.dtype, None, None]
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
return out, state
return out, quant_state
def dequantize_blockwise(
A: Tensor,
quant_state: Tuple[Tensor, Tensor] = None,
quant_state: QuantState = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
......@@ -659,8 +678,8 @@ def dequantize_blockwise(
----------
A : torch.Tensor
The input 8-bit tensor.
quant_state : tuple(torch.Tensor, torch.Tensor)
Tuple of code and absmax values.
quant_state : QuantState
Object with code, absmax and other quantization state components.
absmax : torch.Tensor
The absmax values.
code : torch.Tensor
......@@ -681,36 +700,35 @@ def dequantize_blockwise(
code = name2qmap["dynamic"]
if quant_state is None:
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
if nested:
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
absmax = quant_state.absmax
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float()
if out is None:
out = torch.empty(A.shape, dtype=dtype, device=A.device)
out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
if A.device.type != 'cpu':
device = pre_call(A.device)
code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
code = quant_state.code.to(A.device)
if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
is_on_gpu([A, absmax, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
code = code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
code = quant_state.code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))
return out
......@@ -839,26 +857,26 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
datatype = get_4bit_type(quant_type, device=A.device)
code = get_4bit_type(quant_type, device=A.device)
if compress_statistics:
offset = absmax.mean()
absmax -= offset
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
del absmax
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
else:
state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type)
return out, state
def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.
......@@ -868,8 +886,8 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
----------
A : torch.Tensor
The input 8-bit tensor (packed 4-bit values).
quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
Tuple of absmax values, original tensor shape and original dtype.
quant_state : QuantState
Object with quantization stats, including absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
......@@ -892,41 +910,40 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
if quant_state is None:
assert absmax is not None and out is not None
shape = out.shape
dtype = out.dtype
quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
absmax = quant_state.absmax
if compressed_stats is not None:
offset, state2 = compressed_stats
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float()
if out is None:
out = torch.empty(shape, dtype=dtype, device=A.device)
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
n = out.numel()
device = pre_call(A.device)
is_on_gpu([A, absmax, out])
if out.dtype == torch.float32:
if quant_type == 'fp4':
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
else:
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
elif out.dtype == torch.float16:
if quant_type == 'fp4':
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
else:
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
elif out.dtype == torch.bfloat16:
if quant_type == 'fp4':
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
else:
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
......@@ -952,22 +969,22 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
def dequantize(
A: Tensor,
quant_state: Tuple[Tensor, Tensor] = None,
state: Tuple[Tensor, Tensor] = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
) -> Tensor:
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
assert state is not None or absmax is not None
if code is None and state is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if quant_state is None:
quant_state = (absmax, code)
out = dequantize_no_absmax(A, quant_state[1], out)
return out * quant_state[0]
if state is None:
state = (absmax, code)
out = dequantize_no_absmax(A, state[1], out)
return out * state[0]
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
......@@ -1472,23 +1489,22 @@ def gemv_4bit(
out: Tensor = None,
transposed_A=False,
transposed_B=False,
state=None
quant_state=None
):
prev_device = pre_call(A.device)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None:
if quant_state is None:
raise ValueError('quant_state cannot be None. gemv_4bit() requires the quantization state from quantize_4bit()')
if A.numel() != A.shape[-1]:
raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
Bshape = state[1]
Bshape = quant_state.shape
bout = Bshape[0]
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
if compressed_stats is not None:
offset, state2 = compressed_stats
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
absmax = quant_state.absmax
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if out is None:
if len(A.shape) == 3:
......@@ -1502,7 +1518,7 @@ def gemv_4bit(
lda = Bshape[0]
ldc = Bshape[0]
ldb = (A.shape[-1]+1)//2
is_on_gpu([B, A, out, absmax, state[-1]])
is_on_gpu([B, A, out, absmax, quant_state.code])
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
......@@ -1512,11 +1528,11 @@ def gemv_4bit(
if B.dtype == torch.uint8:
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
......@@ -1798,7 +1814,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
def mm_dequant(
A,
quant_state,
state,
row_stats,
col_stats,
out=None,
......@@ -1808,7 +1824,7 @@ def mm_dequant(
):
assert A.dtype == torch.int32
if bias is not None: assert bias.dtype == torch.float16
out_shape = quant_state[0]
out_shape = state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
......
......@@ -178,22 +178,9 @@ class Params4bit(torch.nn.Parameter):
if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
return self.cuda(device)
else:
s = self.quant_state
if s is not None:
# make sure the quantization state is on the right device
s[0] = s[0].to(device)
if self.compress_statistics:
# TODO: refactor this. This is a nightmare
# for 4-bit:
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
#s[-2][0] = s[-2][0].to(device) # offset
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
# for 8-bit
s[-3][0] = s[-3][0].to(device) # offset
s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics
s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook
if self.quant_state is not None:
self.quant_state.to(device)
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
......@@ -224,11 +211,6 @@ class Linear4bit(nn.Linear):
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
......@@ -270,7 +252,6 @@ class LinearNF4(Linear4bit):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
class Int8Params(torch.nn.Parameter):
def __new__(
cls,
......
......@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
......