Commit 1efb87d8 authored by Tim Dettmers

Added FP8 quantization map.

parent 8d87c0b8
@@ -6,6 +6,7 @@ import ctypes as ct
import operator
import random
import torch
import itertools
from typing import Tuple
from torch import Tensor
@@ -136,6 +137,39 @@ def create_linear_map(signed=True):
    return torch.linspace(0.0, 1.0, 256)
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
    e = exponent_bits
    p = precision_bits
    assert e + p == 7  # 1 sign bit + e exponent bits + p mantissa bits = 8 bits total
    # Exponent values span [-2**(e-1), 2**(e-1) - 1], i.e. an exponent bias of 2**(e-1).
    evalues = []
    pvalues = []
    for val in range(-(2 ** (e - 1)), 2 ** (e - 1)):
        evalues.append(2 ** val)

    # Enumerate all p-bit mantissa patterns; each contributes 1 + sum(b_i * 2**-(i+1)).
    lst = list(itertools.product([0, 1], repeat=p))
    for bit_pattern in lst:
        value = 1
        for i, pval in enumerate(bit_pattern):
            value += pval * (2 ** -(i + 1))
        pvalues.append(value)

    assert len(evalues) * len(pvalues) == 128
    values = []
    for ev in evalues:
        for pv in pvalues:
            values.append(-ev * pv)
            values.append(ev * pv)
    values.sort()
    code = torch.Tensor(values)
    code /= code.max()  # normalize the 256-entry map to [-1, 1]
    code[127] = 0  # pin the midpoint to an exact zero
    return code
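
As a quick sanity check (an editorial sketch, not part of the commit): with the default E5M2 split, the map holds 2**5 exponents times 2**2 mantissa patterns = 128 magnitudes, mirrored across both signs into a 256-entry code normalized to [-1, 1].

    # Editorial sketch, not part of the commit: verify the map's size and
    # range for the default e=5, p=2 configuration.
    code = create_fp8_map(signed=True, exponent_bits=5, precision_bits=2)
    assert code.numel() == 256
    assert code.max().item() == 1.0 and code.min().item() == -1.0
    assert code[127].item() == 0.0  # midpoint pinned to exact zero
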
def create_dynamic_map(signed=True, n=7):
    """
    Creates the dynamic quantization map.
    ...
@@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
    assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
        print(e_bits, p_bits)

        # Case 1: normally distributed data, quantized with the FP8 map.
        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / (torch.abs(A1) + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        print(sum(abserr) / len(abserr))
        print(sum(relerr) / len(relerr))

        # Case 2: uniformly distributed data, quantized with the FP8 map.
        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / (torch.abs(A1) + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        print(sum(abserr) / len(abserr))
        print(sum(relerr) / len(relerr))

        # Case 3: normally distributed data, quantized with the default dynamic map.
        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / (torch.abs(A1) + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        print(3, sum(abserr) / len(abserr))
        print(3, sum(relerr) / len(relerr))
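
For readers trying the new map outside the test suite, a minimal round-trip looks like the following (a sketch distilled from the test above; it assumes a CUDA device and the bitsandbytes functional API imported as F, as in the test file):

    import torch
    import bitsandbytes.functional as F

    # Build an E5M2 quantization map, round-trip a random tensor through
    # blockwise quantization, and report the mean absolute error.
    code = F.create_fp8_map(True, 5, 2).cuda()
    A = torch.randn(1024, 1024, device="cuda")
    C, SC = F.quantize_blockwise(A, code=code)
    A2 = F.dequantize_blockwise(C, SC)
    print((A - A2).abs().mean().item())
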