"csrc/pythonInterface.cpp" did not exist on "3aef78342aec4fff1922c0c2cdd83bdda928b536"
Commit 0f0390ac authored by Tim Dettmers

Added double quantization support and tests.

parent 94168d79
@@ -1461,16 +1461,25 @@ def gemv_4bit(
     transposed_B=False,
     state=None
 ):
+    prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
+        raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
+
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
+
     if out is None:
         out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

     sA = A.shape
     sB = B.shape
     if transposed_A and len(sA) == 2:
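The block above is the core of the double-quantization support: when the quantization state carries compressed_stats, the 8-bit-quantized absmax values are dequantized blockwise and the stored offset is added back before they reach the GEMV kernel. A minimal sketch of that round trip follows, assuming the compressed statistics are produced with F.quantize_blockwise and a stored mean offset; the blocksize and helper names are illustrative assumptions, not the exact internals of quantize_4bit.

    # Sketch only: how per-block absmax statistics can be compressed and restored.
    import torch
    import bitsandbytes.functional as F

    absmax = torch.rand(1024, device='cuda')   # first-level per-block absmax values
    offset = absmax.mean()                      # keep the mean as a float offset
    # 8-bit second-level quantization of the centered statistics (blocksize is an assumption)
    qabsmax, state2 = F.quantize_blockwise(absmax - offset, blocksize=256)

    # Inference-time reconstruction, mirroring the compressed_stats branch above:
    restored = F.dequantize_blockwise(qabsmax, state2) + offset
    print((absmax - restored).abs().max())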
@@ -1557,14 +1566,16 @@ def gemv_4bit(
         if B.dtype == torch.uint8:
             if A.dtype == torch.float16:
-                lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+                lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
             elif A.dtype == torch.bfloat16:
-                lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+                lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
             else:
                 raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

+    post_call(prev_device)
     return out
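For reference, a hedged end-to-end usage sketch of the updated path, mirroring the test further down; shapes and the choice of nf4 are arbitrary and only for illustration.

    # Sketch: quantize a weight matrix with compressed statistics, then run the fused 4-bit GEMV.
    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')         # a single row/token
    B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')  # weight matrix
    qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)

    out = F.gemv_4bit(A, qB.t(), state=state)   # state now carries the compressed absmax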
@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)

-batch_size = 1
+batch_size = 5
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1786,8 +1786,8 @@ values = []
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5120, 4*5120))
-#values.append((batch_size, seqdim, 6656, 4*6656))
-values.append((batch_size, seqdim, 8192, 4*8192))
+values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
     B_nf4, state_nf4 = F.quantize_nf4(B)
+    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()
@@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

-    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

     # warmup
     for i in range(iters):
@@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
+    torch.cuda.synchronize()
+    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
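The new nf4+DQ timing block follows the same warmup-then-synchronize pattern as the other benchmarks in this file. A generic sketch of that pattern (the function name and iteration counts here are illustrative, not part of the test file):

    # Sketch of the CUDA timing pattern used throughout test_bench_matmul.
    import time
    import torch

    def bench(fn, iters=100, warmup=10):
        for _ in range(warmup):
            fn()                      # warmup: exclude one-off launch/allocation cost
        torch.cuda.synchronize()      # kernels are asynchronous; sync before starting the clock
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()      # sync again so the elapsed time covers all queued work
        return time.time() - t0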
@@ -2351,11 +2359,12 @@ def test_normal_map_tree():
     print(pivots)

+@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype, storage_type):
+def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
@@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type):
         max_err = 0
         max_relerr = 0

-        for i in range(1):
+        for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0
-            qB, state = F.quantize_4bit(B, quant_type=storage_type)
-            F.dequantize_4bit(qB, state)
-            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            #F.dequantize_4bit(qB, state)
             C3 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)

             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
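With the new double_quant parameter the test exercises both the plain and the compressed-statistics path. Condensed, each iteration amounts to the comparison below; the error metric and threshold are only illustrative, since the real accumulation into max_err/max_relerr lies outside the visible hunk.

    # Sketch: compare the 4-bit GEMV and matmul_4bit against a fp16/bf16 reference.
    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    dim, dtype, storage_type, double_quant = 4096, torch.float16, 'nf4', True
    A = torch.randn(1, dim, dtype=dtype, device='cuda')
    B = torch.randn(4 * dim, dim, dtype=dtype, device='cuda')

    qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
    C3 = torch.matmul(A, B.t())                  # reference result
    C2 = F.gemv_4bit(A, qB.t(), state=state)     # fused 4-bit GEMV
    C1 = bnb.matmul_4bit(A, qB.t(), state)       # autograd-capable path

    err = (C1 - C3).abs().float().mean().item()  # illustrative metric, not the upstream threshold
    assert err < 1.0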