Commit 4c11d6dc authored by Ruslan Svirschevski

reverted fn signatures in functional()

parent 1d9f0f2a
@@ -569,7 +569,7 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = N
             warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
                 out += bias
             return out
...
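The hunk above only changes the keyword used at the call site: matmul_4bit still forwards the quantization state to F.gemv_4bit, but under the reverted name state=. A minimal caller-level sketch, assuming a CUDA build of bitsandbytes and using only the quantize_4bit/matmul_4bit behaviour visible in this commit (tensor sizes and the 'nf4' quant type are illustrative):

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')   # full-precision weight
    x = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')   # single-token activation

    qW, quant_state = F.quantize_4bit(W, quant_type='nf4')            # 4-bit weight + QuantState
    y = bnb.matmul_4bit(x, qW.t(), quant_state)                       # dispatches to the gemv_4bit fast path above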
@@ -1579,22 +1579,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    quant_state=None
+    state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if quant_state is None:
+    if state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
-    Bshape = quant_state.shape
+    Bshape = state.shape
     bout = Bshape[0]
-    absmax = quant_state.absmax
-    if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-        absmax += quant_state.offset
+    absmax = state.absmax
+    if state.nested:
+        absmax = dequantize_blockwise(state.absmax, state.state2)
+        absmax += state.offset
     if out is None:
         if len(A.shape) == 3:
@@ -1608,7 +1608,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, quant_state.code])
+    is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1618,11 +1618,11 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
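Together, the three gemv_4bit hunks revert the parameter name from quant_state back to state throughout the wrapper. A direct call then looks like the sketch below; the constraints are the ones enforced in the code above (the input must be a single vector, i.e. A.numel() == A.shape[-1], and its dtype must be float16, bfloat16, or float32). This is a hedged sketch with made-up sizes, not repository code:

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device='cuda')  # leading dimensions must be 1

    qW, quant_state = F.quantize_4bit(W, quant_type='nf4')
    # After this commit the keyword is `state=` again; `quant_state=` is no longer accepted.
    y = F.gemv_4bit(x, qW.t(), state=quant_state)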
@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
 def mm_dequant(
     A,
-    state,
+    quant_state,
     row_stats,
     col_stats,
     out=None,
@@ -1914,7 +1914,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = state[0]
+    out_shape = quant_state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
...
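In mm_dequant the rename goes the other way (state becomes quant_state), and the argument here is a plain tuple whose first element is the output shape, as the body shows. A tiny self-contained illustration of that shape handling, with hypothetical values:

    # Hypothetical values; mirrors `out_shape = quant_state[0]` and the 3D-to-2D flattening above.
    batch, seq, hidden = 2, 8, 4096
    quant_state = ((batch, seq, hidden),)      # only the first element is read here
    out_shape = quant_state[0]
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])
    print(out_shape)                           # (16, 4096)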
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
     qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
     C3 = torch.matmul(A, B.t())
-    C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
+    C2 = F.gemv_4bit(A, qB.t(), state=state)
     A.requires_grad = True
     C1 = bnb.matmul_4bit(A, qB.t(), state)
...
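The test exercises the same reverted keyword. A hedged restatement of the relationship it checks, with the pytest parametrization and tolerance checks omitted and concrete sizes made up for illustration:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    A = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

    qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)
    C3 = torch.matmul(A, B.t())                # full-precision reference
    C2 = F.gemv_4bit(A, qB.t(), state=state)   # low-level kernel, reverted keyword
    C1 = bnb.matmul_4bit(A, qB.t(), state)     # high-level entry point
    # C1 and C2 should agree with C3 up to 4-bit quantization error.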