Commit c361f842 authored by Tim Dettmers

Fixed matmul_fp4 transpose.

parent cfe4705e
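
This change flips the transpose convention of matmul_fp4: the dequantized FP4 weight is now transposed inside the forward and backward passes, so for a plain A @ B product callers pass the quantized tensor transposed (e.g. B_fp4.t()), as the updated tests and benchmark below reflect.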
@@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
         # 3. Save state
         ctx.state = state
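
Why the extra .t() in the forward: torch.nn.functional.linear(A, W, bias) computes A @ W.T + bias, so transposing the dequantized matrix first yields A @ deq + bias, treating the FP4 storage as the already-transposed weight. A minimal pure-PyTorch sketch of that identity (the name deq is an illustrative stand-in for F.dequantize_fp4(B, state)):

import torch

A = torch.randn(3, 8)
deq = torch.randn(8, 5)   # stand-in for the dequantized FP4 weight
bias = torch.randn(5)

# The patched forward: linear(A, W, b) computes A @ W.T + b, so W = deq.t()
# makes the product come out as A @ deq + bias.
out = torch.nn.functional.linear(A, deq.t(), bias)
assert torch.allclose(out, A @ deq + bias, atol=1e-6)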
@@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())
         return grad_A, grad_B, None, grad_bias, None
...
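
The backward change mirrors this: with out = A @ deq, the chain rule gives grad_A = grad_output @ deq.T, which is exactly what the added .t() computes. A small autograd cross-check, again with illustrative names:

import torch

A = torch.randn(3, 8, requires_grad=True)
deq = torch.randn(8, 5)   # stand-in for the dequantized FP4 weight

out = A @ deq
grad_output = torch.randn_like(out)
out.backward(grad_output)

# Manual gradient, matching the patched backward line.
manual_grad_A = torch.matmul(grad_output, deq.t())
assert torch.allclose(A.grad, manual_grad_A, atol=1e-6)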
@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
         if has_bias:
             out_torch += bias
...
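
With the fix, the test's transpose flag for the FP4 path tracks the torch reference directly: funcs[0](A, B.t()) pairs with funcs[1](A, B2.t(), ...) and funcs[0](A, B) pairs with funcs[1](A, B2, ...), where B2 is the quantized form of B.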
@@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
...
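
To reproduce the measurement outside the test suite, a sketch of a standalone micro-benchmark, assuming a CUDA device and the dev-era API used in this commit (shapes and iteration count are illustrative):

import time
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

batch, seq, model, hidden = 1, 128, 1024, 4096
iters = 100

A = torch.randn(batch, seq, model, dtype=torch.float16, device="cuda")
B = torch.randn(hidden, model, dtype=torch.float16, device="cuda")
# quantize_fp4 returns the packed FP4 tensor and its quantization state.
B_fp4, state = F.quantize_fp4(B)

torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
    # After this commit, the quantized weight is passed transposed.
    bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
torch.cuda.synchronize()
print(f"bnb fp4: [{batch},{seq},{model}] x [{model},{hidden}]: {time.time()-t0:.4f}s")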