Commit 4c11d6dc authored by Ruslan Svirschevski

reverted fn signatures in functional()

parent 1d9f0f2a
@@ -569,7 +569,7 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = N
warn(f'The hidden dimension of some matrices is not a multiple of {quant_state.blocksize}; efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
-out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
+out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
return out
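Caller-facing effect of the hunk above: `gemv_4bit` takes its `QuantState` through the `state=` keyword again rather than `quant_state=`. A minimal usage sketch, not part of the commit, assuming bitsandbytes with the signatures shown in this diff and an available CUDA device (shapes are illustrative):

```python
import torch
import bitsandbytes.functional as F

# Quantize a weight matrix to 4 bits; quantize_4bit returns the packed weights and a QuantState.
B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qB, quant_state = F.quantize_4bit(B, quant_type="nf4")

# Inference-time GEMV: A must be a vector with leading dimensions of 1, e.g. [1, 1, hidden].
A = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")

# After this commit the QuantState is passed via the `state` keyword again.
out = F.gemv_4bit(A, qB.t(), state=quant_state)
```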
@@ -1579,22 +1579,22 @@ def gemv_4bit(
out: Tensor = None,
transposed_A=False,
transposed_B=False,
-quant_state=None
+state=None
):
prev_device = pre_call(A.device)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-if quant_state is None:
+if state is None:
raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
if A.numel() != A.shape[-1]:
raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
-Bshape = quant_state.shape
+Bshape = state.shape
bout = Bshape[0]
-absmax = quant_state.absmax
-if quant_state.nested:
-    absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-    absmax += quant_state.offset
+absmax = state.absmax
+if state.nested:
+    absmax = dequantize_blockwise(state.absmax, state.state2)
+    absmax += state.offset
if out is None:
if len(A.shape) == 3:
@@ -1608,7 +1608,7 @@ def gemv_4bit(
lda = Bshape[0]
ldc = Bshape[0]
ldb = (A.shape[-1]+1)//2
-is_on_gpu([B, A, out, absmax, quant_state.code])
+is_on_gpu([B, A, out, absmax, state.code])
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
@@ -1618,11 +1618,11 @@ def gemv_4bit(
if B.dtype == torch.uint8:
if A.dtype == torch.float16:
-lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
elif A.dtype == torch.bfloat16:
-lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
elif A.dtype == torch.float32:
-lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
def mm_dequant(
A,
-state,
+quant_state,
row_stats,
col_stats,
out=None,
......@@ -1914,7 +1914,7 @@ def mm_dequant(
):
assert A.dtype == torch.int32
if bias is not None: assert bias.dtype == torch.float16
-out_shape = state[0]
+out_shape = quant_state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
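In the hunk above, `mm_dequant`'s second parameter goes back to the name `quant_state`; in the library's int8 matmul path it is the layout/shape tuple returned by `igemmlt` (its first element is the output shape read a few lines further down) and is passed positionally, so such callers are unaffected by the rename. A rough sketch of that path, not part of the commit, assuming a CUDA device and that the int8 helpers `double_quant`, `transform`, `get_special_format_str`, and `igemmlt` behave as in the library version this diff targets:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(32, 4096, dtype=torch.float16, device="cuda")    # activations
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # weight (out_features x in_features)

# Row-wise int8 quantization of both operands; the stats are needed later for dequantization.
CA, CAt, row_stats_A, col_stats_A, _ = F.double_quant(A)
CW, CWt, row_stats_W, col_stats_W, _ = F.double_quant(W)

# Convert to the tile layouts expected by igemmlt on this GPU generation.
format_w = F.get_special_format_str()          # e.g. "col_turing" or "col_ampere"
C32A, SA = F.transform(CA, "col32")
CxW, SW = F.transform(CW, to_order=format_w)

# int32 accumulator plus its layout/shape state (the `quant_state` argument below).
out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)

# Passed positionally, so the `state` -> `quant_state` rename is transparent to this call.
out = F.mm_dequant(out32, Sout32, row_stats_A, row_stats_W)
```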
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
-C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
+C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)