Commit cbfdf0b5 authored by justheuristic
Browse files

cast edge case

parent e35e2c66
...@@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.B = B ctx.B = B
ctx.bias = bias ctx.bias = bias
if A.shape[-1] == B.shape[0]: if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else: else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A # 1. Quantize A
# 2. Quantize B # 2. Quantize B
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment