Commit 0d332a64 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Added normal with extra value.

parent 2dd5d690
...@@ -9,7 +9,7 @@ import random ...@@ -9,7 +9,7 @@ import random
import torch import torch
import itertools import itertools
import math import math
import scipy.stats from scipy.stats import norm
import numpy as np import numpy as np
from functools import reduce # Required in Python 3 from functools import reduce # Required in Python 3
...@@ -181,7 +181,7 @@ def create_custom_map(seed=0, scale=0.01): ...@@ -181,7 +181,7 @@ def create_custom_map(seed=0, scale=0.01):
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
#v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
#v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88
# 7B evo start # 7B evo start
#v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
...@@ -197,9 +197,7 @@ def create_custom_map(seed=0, scale=0.01): ...@@ -197,9 +197,7 @@ def create_custom_map(seed=0, scale=0.01):
#v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
# theoretically optiomal (0.93333) # theoretically optiomal (0.93333)
v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
if seed > 0: if seed > 0:
v = np.array(v) v = np.array(v)
...@@ -220,6 +218,26 @@ def create_custom_map(seed=0, scale=0.01): ...@@ -220,6 +218,26 @@ def create_custom_map(seed=0, scale=0.01):
assert values.numel() == 256 assert values.numel() == 256
return values return values
def create_normal_map(offset=0.966666, use_extra_value=True):
if use_extra_value:
# one more positive value, this is an asymmetric type
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
v = v1 + v2 + v3
else:
v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
v = v1 + v2 + v3
values = torch.Tensor(v)
values = values.sort().values
values /= values.max()
assert values.numel() == 256
return values
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits e = exponent_bits
p = precision_bits p = precision_bits
......
...@@ -2318,6 +2318,3 @@ def test_bench_fp4_dequant(): ...@@ -2318,6 +2318,3 @@ def test_bench_fp4_dequant():
# torch.matmul(b, a.t()) # torch.matmul(b, a.t())
#torch.cuda.synchronize() #torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6) #print((time.time()-t0)/iters*1e6)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment