Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
98cbc4bc
Commit
98cbc4bc
authored
Nov 06, 2022
by
Tim Dettmers
Browse files
Added k-bit fp8 map.
parent
caf18325
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
52 deletions
+52
-52
bitsandbytes/functional.py
bitsandbytes/functional.py
+11
-5
tests/test_functional.py
tests/test_functional.py
+41
-47
No files found.
bitsandbytes/functional.py
View file @
98cbc4bc
...
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
...
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
return
torch
.
Tensor
(
values
[:
l
].
tolist
()
+
[
0
]
*
gap
+
values
[
l
:].
tolist
())
return
torch
.
Tensor
(
values
[:
l
].
tolist
()
+
[
0
]
*
gap
+
values
[
l
:].
tolist
())
def
create_fp8_map
(
signed
=
True
,
exponent_bits
=
5
,
precision_bits
=
2
):
def
create_fp8_map
(
signed
=
True
,
exponent_bits
=
5
,
precision_bits
=
2
,
total_bits
=
8
):
e
=
exponent_bits
e
=
exponent_bits
p
=
precision_bits
p
=
precision_bits
assert
e
+
p
==
7
has_sign
=
1
if
signed
else
0
assert
e
+
p
==
total_bits
-
has_sign
# the exponent is biased to 2^(e-1) -1 == 0
# the exponent is biased to 2^(e-1) -1 == 0
evalues
=
[]
evalues
=
[]
pvalues
=
[]
pvalues
=
[]
for
i
,
val
in
enumerate
(
range
(
-
((
2
**
(
exponent_bits
-
1
))),
2
**
(
exponent_bits
-
1
),
1
)):
for
i
,
val
in
enumerate
(
range
(
-
((
2
**
(
exponent_bits
-
has_sign
))),
2
**
(
exponent_bits
-
has_sign
),
1
)):
evalues
.
append
(
2
**
val
)
evalues
.
append
(
2
**
val
)
...
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
...
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
value
+=
pval
*
(
2
**-
(
i
+
1
))
value
+=
pval
*
(
2
**-
(
i
+
1
))
pvalues
.
append
(
value
)
pvalues
.
append
(
value
)
assert
len
(
evalues
)
*
len
(
pvalues
)
==
128
assert
len
(
evalues
)
*
len
(
pvalues
)
==
2
**
(
total_bits
-
has_sign
)
values
=
[]
values
=
[]
for
ev
in
evalues
:
for
ev
in
evalues
:
for
pv
in
pvalues
:
for
pv
in
pvalues
:
if
signed
:
values
.
append
(
-
ev
*
pv
)
values
.
append
(
-
ev
*
pv
)
values
.
append
(
ev
*
pv
)
values
.
append
(
ev
*
pv
)
if
total_bits
<
8
:
gap
=
256
-
len
(
values
)
for
i
in
range
(
gap
):
values
.
append
(
0
)
values
.
sort
()
values
.
sort
()
code
=
torch
.
Tensor
(
values
)
code
=
torch
.
Tensor
(
values
)
code
/=
code
.
max
()
code
/=
code
.
max
()
...
...
tests/test_functional.py
View file @
98cbc4bc
...
@@ -11,7 +11,7 @@ import bitsandbytes as bnb
...
@@ -11,7 +11,7 @@ import bitsandbytes as bnb
from
bitsandbytes
import
functional
as
F
from
bitsandbytes
import
functional
as
F
torch
.
set_printoptions
(
torch
.
set_printoptions
(
precision
=
4
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
precision
=
5
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
)
)
k
=
20
k
=
20
...
@@ -2095,12 +2095,21 @@ def test_fp8_quant():
...
@@ -2095,12 +2095,21 @@ def test_fp8_quant():
def
test_few_bit_quant
():
def
test_few_bit_quant
():
for
bits
in
range
(
2
,
9
):
for
bits
in
range
(
2
,
9
):
for
method
in
[
'linear'
,
'fp8'
]:
code
=
None
if
method
==
'linear'
:
code
=
F
.
create_linear_map
(
True
,
bits
=
bits
).
cuda
()
code
=
F
.
create_linear_map
(
True
,
bits
=
bits
).
cuda
()
elif
method
==
'fp8'
:
ebits
=
math
.
ceil
(
bits
/
2
)
pbits
=
bits
-
ebits
-
1
code
=
F
.
create_fp8_map
(
True
,
ebits
,
pbits
,
bits
).
cuda
()
print
(
ebits
,
pbits
,
bits
)
print
(
code
)
assert
code
.
numel
()
==
256
assert
code
.
numel
()
==
256
print
(
bits
)
print
(
bits
)
for
i
in
range
(
10
0
):
for
i
in
range
(
10
):
values
=
torch
.
randn
(
1
,
2
4
,
device
=
'cuda'
)
values
=
torch
.
randn
(
1
,
3
2
,
device
=
'cuda'
)
values
/=
values
.
abs
().
max
()
values
/=
values
.
abs
().
max
()
#values[values.abs() < 1e-6] += 1e-5
#values[values.abs() < 1e-6] += 1e-5
...
@@ -2126,18 +2135,3 @@ def test_few_bit_quant():
...
@@ -2126,18 +2135,3 @@ def test_few_bit_quant():
else
:
else
:
torch
.
testing
.
assert_allclose
(
q1
,
q2
)
torch
.
testing
.
assert_allclose
(
q1
,
q2
)
#print(e_bits, p_bits)
#abserr = []
#relerr = []
#for i in range(100):
# A1 = torch.randn(1024, 1024, device="cuda")
# C, SC = F.quantize_blockwise(A1, code=code)
# A2 = F.dequantize_blockwise(C, SC)
# diff = torch.abs(A1 - A2)
# reldiff = diff/torch.abs(A1+1e-8)
# abserr.append(diff.mean().item())
# relerr.append(reldiff.mean().item())
# #assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment