Commit 61a4a20d authored by Ruslan Svirschevski

use QuantState class for quant_state

parent e812136c
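In short, the positional list that previously carried quantization metadata is replaced by a small QuantState object with named attributes. A minimal illustrative sketch (placeholder values; the field names mirror the old list layout visible in the diff below):

    import torch
    from bitsandbytes.functional import QuantState

    # The old 4-bit layout was [absmax, input_shape, dtype, blocksize, [offset, state2], quant_type, datatype];
    # the same information now travels as named attributes.
    absmax = torch.ones(256)                    # placeholder per-block scales
    code = torch.linspace(-1.0, 1.0, 16)        # placeholder 4-bit codebook
    qs = QuantState(absmax=absmax, shape=torch.Size([128, 128]), code=code,
                    blocksize=64, quant_type='nf4', dtype=torch.float16)
    assert qs.nested is False                   # becomes True only when offset/state2 are attached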
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -504,7 +504,7 @@ class MatMul4Bit(torch.autograd.Function):
             ctx.A = A
             ctx.B = B
             ctx.bias = bias
-            B_shape = state[1]
+            B_shape = quant_state.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -513,10 +513,10 @@ class MatMul4Bit(torch.autograd.Function):
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
-        ctx.state = state
+        ctx.state = quant_state
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

         if any(ctx.needs_input_grad[:2]):
@@ -534,7 +534,6 @@ class MatMul4Bit(torch.autograd.Function):
         req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
         A, B = ctx.tensors
-        state = ctx.state

         grad_A, grad_B, grad_bias = None, None, None
@@ -563,15 +562,14 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if A.shape[-1] % blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+        if A.shape[-1] % quant_state.blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
             if bias is not None:
                 out += bias
             return out
...
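For reference, a hedged usage sketch of the 4-bit matmul path after this change (shapes, device and the 'nf4' choice are illustrative; the calls mirror those in the diff and in test_gemv_4bit):

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qW, quant_state = F.quantize_4bit(W, quant_type='nf4')   # quant_state is now a QuantState object
    x = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
    y = bnb.matmul_4bit(x, qW.t(), quant_state)              # single-row, no-grad inputs take the F.gemv_4bit(..., quant_state=...) path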
@@ -566,6 +566,25 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out


+class QuantState:
+    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
+        self.absmax = absmax
+        self.shape = shape
+        self.code = code
+        self.dtype = dtype
+        self.blocksize = blocksize
+        self.quant_type = quant_type
+        self.offset = offset
+        self.state2 = state2
+        self.nested = state2 is not None
+
+    def to(self, device):
+        # make sure the quantization state is on the right device
+        self.absmax = self.absmax.to(device)
+        if self.nested:
+            self.offset = self.offset.to(device)
+            self.state2.absmax = self.state2.absmax.to(device)
+            self.state2.code = self.state2.code.to(device)
+
+
 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
@@ -633,16 +652,16 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
+        quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
     else:
-        state = [absmax, code, blocksize, nested, A.dtype, None, None]
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)

-    return out, state
+    return out, quant_state
 def dequantize_blockwise(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    quant_state: QuantState = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
@@ -659,8 +678,8 @@ def dequantize_blockwise(
     ----------
     A : torch.Tensor
         The input 8-bit tensor.
-    quant_state : tuple(torch.Tensor, torch.Tensor)
-        Tuple of code and absmax values.
+    quant_state : QuantState
+        Object with code, absmax and other quantization state components.
     absmax : torch.Tensor
         The absmax values.
     code : torch.Tensor
@@ -681,36 +700,35 @@ def dequantize_blockwise(
         code = name2qmap["dynamic"]

     if quant_state is None:
-        quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)

-    absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-    if nested:
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset

     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(A.shape, dtype=dtype, device=A.device)
+        out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)

     if A.device.type != 'cpu':
         device = pre_call(A.device)
-        code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+        code = quant_state.code.to(A.device)
+        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.bfloat16:
-            lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
-        code = code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+        code = quant_state.code.cpu()
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))

     return out
@@ -839,26 +857,26 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

-    datatype = get_4bit_type(quant_type, device=A.device)
+    code = get_4bit_type(quant_type, device=A.device)

     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
+        state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
+        state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, )

     return out, state


-def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')


-def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')


-def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
@@ -868,8 +886,8 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     ----------
     A : torch.Tensor
         The input 8-bit tensor (packed 4-bit values).
-    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
-        Tuple of absmax values, original tensor shape and original dtype.
+    quant_state : QuantState
+        object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
@@ -892,41 +910,40 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     if quant_state is None:
         assert absmax is not None and out is not None
-        shape = out.shape
-        dtype = out.dtype
+        quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if compressed_stats is not None:
-            offset, state2 = compressed_stats
-            absmax = dequantize_blockwise(absmax, state2)
-            absmax += offset
+        absmax = quant_state.absmax
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+            absmax += quant_state.offset

     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(shape, dtype=dtype, device=A.device)
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

     n = out.numel()

     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.float16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.bfloat16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
@@ -952,22 +969,22 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:

 def dequantize(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    state: Tuple[Tensor, Tensor] = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
 ) -> Tensor:
-    assert quant_state is not None or absmax is not None
-    if code is None and quant_state is None:
+    assert state is not None or absmax is not None
+    if code is None and state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
         code = code.to(A.device)

-    if quant_state is None:
-        quant_state = (absmax, code)
-    out = dequantize_no_absmax(A, quant_state[1], out)
-    return out * quant_state[0]
+    if state is None:
+        state = (absmax, code)
+    out = dequantize_no_absmax(A, state[1], out)
+    return out * state[0]


 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
@@ -1472,23 +1489,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    quant_state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if state is None:
+    if quant_state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = state[1]
+    Bshape = quant_state.shape
     bout = Bshape[0]
-    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1502,7 +1518,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, state[-1]])
+    is_on_gpu([B, A, out, absmax, quant_state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1512,11 +1528,11 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
@@ -1798,7 +1814,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):

 def mm_dequant(
     A,
-    quant_state,
+    state,
     row_stats,
     col_stats,
     out=None,
@@ -1808,7 +1824,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
+    out_shape = state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
...
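A minimal round-trip sketch for the blockwise functions above (tensor size and the CUDA device are assumptions for illustration):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qA, quant_state = F.quantize_blockwise(A, blocksize=4096, nested=True)
    # quant_state is a QuantState whose absmax is itself blockwise-quantized (quant_state.state2, quant_state.offset)
    A_restored = F.dequantize_blockwise(qA, quant_state)     # code, blocksize and dtype are read from the object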
@@ -178,22 +178,9 @@ class Params4bit(torch.nn.Parameter):
         if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
             return self.cuda(device)
         else:
-            s = self.quant_state
-            if s is not None:
-                # make sure the quantization state is on the right device
-                s[0] = s[0].to(device)
-                if self.compress_statistics:
-                    # TODO: refactor this. This is a nightmare
-                    # for 4-bit:
-                    # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
-                    # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
-                    #s[-2][0] = s[-2][0].to(device) # offset
-                    #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
-                    # for 8-bit
-                    s[-3][0] = s[-3][0].to(device) # offset
-                    s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics
-                    s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook
+            if self.quant_state is not None:
+                self.quant_state.to(device)

             new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                    requires_grad=self.requires_grad, quant_state=self.quant_state,
                                    blocksize=self.blocksize, compress_statistics=self.compress_statistics,
@@ -224,11 +211,6 @@ class Linear4bit(nn.Linear):
             warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
             warnings.filterwarnings('ignore', message='.*inference or training')

     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -270,7 +252,6 @@ class LinearNF4(Linear4bit):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)


 class Int8Params(torch.nn.Parameter):
     def __new__(
         cls,
...
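With QuantState owning its own .to(device), moving a quantized module between devices no longer needs the index-based bookkeeping deleted above. A sketch assuming a CUDA-enabled bitsandbytes install (layer size and quantization options are illustrative):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.float16, compress_statistics=True, quant_type='nf4')
    layer = layer.cuda()                                 # Params4bit quantizes the weight and attaches a QuantState
    layer.weight.quant_state.to(torch.device('cuda'))    # one call moves absmax, offset and the nested state2 tensors
                                                         # (Params4bit.to()/cuda() already do this; shown here for illustration)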
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
         qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
         C3 = torch.matmul(A, B.t())
-        C2 = F.gemv_4bit(A, qB.t(), state=state)
+        C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
         A.requires_grad = True
         C1 = bnb.matmul_4bit(A, qB.t(), state)
...