Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
eb028e6e
Commit
eb028e6e
authored
Nov 19, 2022
by
Tim Dettmers
Browse files
Fixed k-bit quantization maps.
parent
08fa2e7b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
28 deletions
+69
-28
bitsandbytes/functional.py
bitsandbytes/functional.py
+46
-16
tests/test_functional.py
tests/test_functional.py
+23
-12
No files found.
bitsandbytes/functional.py
View file @
eb028e6e
...
...
@@ -7,6 +7,7 @@ import operator
import
random
import
torch
import
itertools
import
math
from
typing
import
Tuple
from
torch
import
Tensor
...
...
@@ -130,10 +131,17 @@ class Cusparse_Context(object):
return
cls
.
_instance
def
create_linear_map
(
signed
=
True
,
total_bits
=
8
):
def
create_linear_map
(
signed
=
True
,
total_bits
=
8
,
add_zero
=
True
):
sign
=
(
-
1.0
if
signed
else
0.0
)
values
=
torch
.
linspace
(
sign
,
1.0
,
2
**
total_bits
)
total_values
=
2
**
total_bits
if
add_zero
or
total_bits
<
8
:
# add a zero
# since we simulate less bits by having zeros in the data type, we
# we need to center the quantization around zero and as such lose
# a single value
total_values
=
(
2
**
total_bits
if
not
signed
else
2
**
total_bits
-
1
)
values
=
torch
.
linspace
(
sign
,
1.0
,
total_values
)
gap
=
256
-
values
.
numel
()
if
gap
==
0
:
return
values
...
...
@@ -155,20 +163,28 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
evalues
.
append
(
2
**
val
)
lst
=
list
(
itertools
.
product
([
0
,
1
],
repeat
=
precision_bits
))
for
bit_pattern
in
lst
:
value
=
1
for
i
,
pval
in
enumerate
(
list
(
bit_pattern
)):
value
+=
pval
*
(
2
**-
(
i
+
1
))
pvalues
.
append
(
value
)
assert
len
(
evalues
)
*
len
(
pvalues
)
==
2
**
(
total_bits
-
has_sign
)
values
=
[]
for
ev
in
evalues
:
for
pv
in
pvalues
:
lst
=
list
(
itertools
.
product
([
0
,
1
],
repeat
=
precision_bits
))
#for ev in evalues:
bias
=
2
**
(
exponent_bits
-
1
)
-
1
for
evalue
in
range
(
2
**
(
exponent_bits
)):
for
bit_pattern
in
lst
:
value
=
(
1
if
evalue
!=
0
else
0
)
for
i
,
pval
in
enumerate
(
list
(
bit_pattern
)):
value
+=
pval
*
(
2
**-
(
i
+
1
))
if
evalue
==
0
:
# subnormals
value
=
value
*
2
**-
(
bias
-
1
)
else
:
# normals
value
=
value
*
2
**-
(
evalue
-
bias
-
2
)
values
.
append
(
value
)
if
signed
:
values
.
append
(
-
ev
*
pv
)
values
.
append
(
ev
*
pv
)
values
.
append
(
-
value
)
assert
len
(
values
)
==
2
**
total_bits
values
.
sort
()
if
total_bits
<
8
:
gap
=
256
-
len
(
values
)
for
i
in
range
(
gap
):
...
...
@@ -176,7 +192,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
values
.
sort
()
code
=
torch
.
Tensor
(
values
)
code
/=
code
.
max
()
code
[
127
]
=
0
return
code
...
...
@@ -232,6 +247,20 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data
.
sort
()
return
Tensor
(
data
)
def
create_quantile_map
(
A
,
total_bits
=
8
):
q
=
estimate_quantiles
(
A
,
num_quantiles
=
2
**
total_bits
-
1
)
q
=
q
.
tolist
()
q
.
append
(
0
)
gap
=
256
-
len
(
q
)
for
i
in
range
(
gap
):
q
.
append
(
0
)
q
.
sort
()
q
=
Tensor
(
q
)
q
=
q
/
q
.
abs
().
max
()
return
q
def
get_special_format_str
():
if
not
torch
.
cuda
.
is_available
():
return
'col_turing'
...
...
@@ -422,6 +451,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
post_call
(
device
)
if
num_quantiles
<
256
:
step
=
round
(
256
/
num_quantiles
)
idx
=
torch
.
linspace
(
0
,
255
,
num_quantiles
).
long
().
to
(
A
.
device
)
out
=
out
[
idx
]
...
...
tests/test_functional.py
View file @
eb028e6e
...
...
@@ -2113,15 +2113,11 @@ def test_few_bit_quant():
code
=
F
.
create_dynamic_map
(
True
,
bits
-
0
,
bits
).
cuda
()
elif
method
==
'quantile'
:
values
=
torch
.
randn
(
2048
,
2048
,
device
=
'cuda'
)
q
=
F
.
estimate_quantiles
(
values
,
offset
=
1
/
(
2
*
(
2
**
bits
)),
num_quantiles
=
2
**
bits
)
gap
=
256
-
q
.
numel
()
q
=
q
.
tolist
()
for
i
in
range
(
gap
):
q
.
append
(
0
)
q
=
torch
.
Tensor
(
q
).
cuda
()
q
/=
q
.
abs
().
max
()
code
,
idx
=
torch
.
sort
(
q
)
code
=
F
.
create_quantile_map
(
values
,
bits
).
cuda
()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert
torch
.
unique
(
code
).
numel
()
in
[
2
**
bits
,
2
**
bits
-
1
],
f
'bits:
{
bits
}
, method:
{
method
}
'
#print(method, (code==0).sum())
assert
code
.
numel
()
==
256
for
i
in
range
(
10
):
...
...
@@ -2140,8 +2136,8 @@ def test_few_bit_quant():
q1
=
torch
.
Tensor
(
q1
).
cuda
()
v1
=
torch
.
Tensor
(
v1
).
cuda
()
q2
,
S2
=
F
.
quantize
(
values
,
code
=
code
)
v2
=
F
.
dequantize
(
q2
,
S2
)
q2
,
S2
=
F
.
quantize
_blockwise
(
values
,
code
=
code
)
v2
=
F
.
dequantize
_blockwise
(
q2
,
S2
)
idx
=
torch
.
isclose
(
q1
.
int
(),
q2
.
int
())
err2
=
torch
.
abs
(
v2
-
values
)
...
...
@@ -2150,11 +2146,12 @@ def test_few_bit_quant():
if
idx
.
sum
():
# some weird cases
err1
=
torch
.
abs
(
v1
-
values
).
mean
()
assert
err2
.
mean
()
<=
err1
#
assert err2.mean() <= err1
else
:
torch
.
testing
.
assert_allclose
(
q1
,
q2
)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
def
test_kbit_quantile_estimation
():
...
...
@@ -2165,6 +2162,20 @@ def test_kbit_quantile_estimation():
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val2
=
F
.
estimate_quantiles
(
data
,
offset
=
0
,
num_quantiles
=
2
**
bits
)
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
assert
err
<
0.038
for
i
in
range
(
100
):
data
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
)
for
bits
in
range
(
2
,
4
):
total_values
=
2
**
bits
-
1
p
=
np
.
linspace
(
0
,
1
,
2
*
total_values
+
1
)
idx
=
np
.
arange
(
1
,
2
*
total_values
+
1
,
2
)
p
=
p
[
idx
]
offset
=
1
/
(
2
*
total_values
)
p
=
np
.
linspace
(
offset
,
1
-
offset
,
total_values
)
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val2
=
F
.
estimate_quantiles
(
data
,
num_quantiles
=
2
**
bits
-
1
)
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
assert
err
<
0.035
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment