Commit 98cbc4bc authored by Tim Dettmers

Added k-bit fp8 map.

parent caf18325
bitsandbytes/functional.py

@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
     return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())

-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
     p = precision_bits
-    assert e+p == 7
+    has_sign = 1 if signed else 0
+    assert e+p == total_bits-has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     pvalues = []
-    for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
         evalues.append(2**val)
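
Note on the hunk above: with a sign bit reserved, the biased exponent loop now runs over range(-(2**(e-1)), 2**(e-1)). A small worked check of what that produces (my own illustration, not part of the commit):

    # With exponent_bits=5 and has_sign=1, the loop enumerates
    # range(-16, 16): 32 exponents, so evalues spans 2**-16 ... 2**15.
    exponent_bits, has_sign = 5, 1
    evalues = [2**v for v in range(-(2**(exponent_bits - has_sign)), 2**(exponent_bits - has_sign))]
    assert len(evalues) == 2**exponent_bits  # 32 exponent values
    assert min(evalues) == 2**-16 and max(evalues) == 2**15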
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
             value += pval*(2**-(i+1))
         pvalues.append(value)

-    assert len(evalues)*len(pvalues) == 128
+    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
     for ev in evalues:
         for pv in pvalues:
-            values.append(-ev*pv)
+            if signed:
+                values.append(-ev*pv)
             values.append(ev*pv)
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
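
Taken together, the new total_bits parameter turns create_fp8_map into a general k-bit float map: e + p plus the sign bit must equal total_bits, and any map with fewer than 8 bits is zero-padded up to the fixed 256-entry codebook that the quantization kernels expect. A minimal usage sketch (the parameter values below are illustrative, not from this commit):

    from bitsandbytes import functional as F

    # A 4-bit float code: 1 sign bit, 2 exponent bits, 1 precision bit,
    # so e + p == total_bits - has_sign (2 + 1 == 4 - 1) holds.
    code = F.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)

    assert code.numel() == 256        # zero-padded up to the fixed 256-entry table
    assert code.max().item() == 1.0   # sorted and normalized so the largest value is 1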
tests/test_functional.py

@@ -11,7 +11,7 @@ import bitsandbytes as bnb
 from bitsandbytes import functional as F

 torch.set_printoptions(
-    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
 )
 k = 20
@@ -2095,49 +2095,43 @@ def test_fp8_quant():

 def test_few_bit_quant():

     for bits in range(2, 9):
-        code = F.create_linear_map(True, bits=bits).cuda()
-        assert code.numel() == 256
-        print(bits)
-        for i in range(100):
-
-            values = torch.randn(1, 24, device='cuda')
-            values /= values.abs().max()
-            #values[values.abs() < 1e-6] += 1e-5
-
-            q1 = []
-            v1 = []
-            for v in values[0]:
-                idx = torch.abs(v-code).argmin()
-                q1.append(idx.item())
-                v1.append(code[idx].item())
-
-            q1 = torch.Tensor(q1).cuda()
-            v1 = torch.Tensor(v1).cuda()
-
-            q2, S2 = F.quantize(values, code=code)
-            v2 = F.dequantize(q2, S2)
-
-            idx = torch.isclose(q1.int(), q2.int())
-            if idx.sum():
-                # some weird cases
-                err1 = torch.abs(v1-values).mean()
-                err2 = torch.abs(v2-values).mean()
-                assert err2 <= err1
-            else:
-                torch.testing.assert_allclose(q1, q2)
-
-    #print(e_bits, p_bits)
-    #abserr = []
-    #relerr = []
-    #for i in range(100):
-    #    A1 = torch.randn(1024, 1024, device="cuda")
-    #    C, SC = F.quantize_blockwise(A1, code=code)
-    #    A2 = F.dequantize_blockwise(C, SC)
-    #    diff = torch.abs(A1 - A2)
-    #    reldiff = diff/torch.abs(A1+1e-8)
-    #    abserr.append(diff.mean().item())
-    #    relerr.append(reldiff.mean().item())
-    #    #assert diff < 0.0075
-    #print(sum(abserr)/len(abserr))
-    #print(sum(relerr)/len(relerr))
+        for method in ['linear', 'fp8']:
+            code = None
+            if method == 'linear':
+                code = F.create_linear_map(True, bits=bits).cuda()
+            elif method == 'fp8':
+                ebits = math.ceil(bits/2)
+                pbits = bits-ebits-1
+                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+                print(ebits, pbits, bits)
+                print(code)
+            assert code.numel() == 256
+            print(bits)
+            for i in range(10):
+
+                values = torch.randn(1, 32, device='cuda')
+                values /= values.abs().max()
+                #values[values.abs() < 1e-6] += 1e-5
+
+                q1 = []
+                v1 = []
+                for v in values[0]:
+                    idx = torch.abs(v-code).argmin()
+                    q1.append(idx.item())
+                    v1.append(code[idx].item())
+
+                q1 = torch.Tensor(q1).cuda()
+                v1 = torch.Tensor(v1).cuda()
+
+                q2, S2 = F.quantize(values, code=code)
+                v2 = F.dequantize(q2, S2)
+
+                idx = torch.isclose(q1.int(), q2.int())
+                if idx.sum():
+                    # some weird cases
+                    err1 = torch.abs(v1-values).mean()
+                    err2 = torch.abs(v2-values).mean()
+                    assert err2 <= err1
+                else:
+                    torch.testing.assert_allclose(q1, q2)
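
For reference, the ebits/pbits split used in the new test reserves one bit for the sign and gives the exponent the larger share of what remains. A short sketch of the resulting layouts (computed from the formulas above; the printed summary is my own):

    import math

    # Bit layouts the test constructs for each total width k = 2..8:
    # one sign bit, ceil(k/2) exponent bits, the remainder as precision bits.
    for bits in range(2, 9):
        ebits = math.ceil(bits / 2)
        pbits = bits - ebits - 1
        print(f'{bits}-bit float: sign=1, exponent={ebits}, precision={pbits}')
    # e.g. 8-bit float: sign=1, exponent=4, precision=3 (an e4m3-style layout)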