Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
caf18325
Commit
caf18325
authored
Nov 06, 2022
by
Tim Dettmers
Browse files
Added k-bit linear quantization.
parent
1efb87d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
4 deletions
+60
-4
bitsandbytes/functional.py
bitsandbytes/functional.py
+10
-4
tests/test_functional.py
tests/test_functional.py
+50
-0
No files found.
bitsandbytes/functional.py
View file @
caf18325
...
@@ -130,11 +130,17 @@ class Cusparse_Context(object):
...
@@ -130,11 +130,17 @@ class Cusparse_Context(object):
return
cls
.
_instance
return
cls
.
_instance
def create_linear_map(signed=True, bits=8):
    """Build a 256-entry linear quantization lookup table.

    The table holds ``2**bits`` linearly spaced levels spanning ``[-1, 1]``
    when ``signed`` is true and ``[0, 1]`` otherwise. When that is fewer
    than 256 levels, the remaining slots are filled with zeros inserted
    between the lower and upper halves of the levels, so the result always
    has 256 entries.

    Args:
        signed: If true the levels start at -1.0, otherwise at 0.0.
        bits: Number of bits to simulate; ``2**bits`` levels are generated.

    Returns:
        A 1-D tensor with 256 elements.

    Raises:
        ValueError: If ``bits`` is greater than 8, because more than 256
            levels cannot fit into the 256-entry table.
    """
    if bits > 8:
        raise ValueError(
            f"bits must be at most 8 to fit a 256-entry map, got {bits}"
        )

    lower = -1.0 if signed else 0.0
    values = torch.linspace(lower, 1.0, 2 ** bits)

    gap = 256 - values.numel()
    if gap == 0:
        return values

    # Pad in the middle so the lower half of the levels stays at the front
    # of the table and the upper half at the back.
    half = values.numel() // 2
    return torch.cat([values[:half], values.new_zeros(gap), values[half:]])
def
create_fp8_map
(
signed
=
True
,
exponent_bits
=
5
,
precision_bits
=
2
):
def
create_fp8_map
(
signed
=
True
,
exponent_bits
=
5
,
precision_bits
=
2
):
...
...
tests/test_functional.py
View file @
caf18325
...
@@ -2091,3 +2091,53 @@ def test_fp8_quant():
...
@@ -2091,3 +2091,53 @@ def test_fp8_quant():
print
(
3
,
sum
(
abserr
)
/
len
(
abserr
))
print
(
3
,
sum
(
abserr
)
/
len
(
abserr
))
print
(
3
,
sum
(
relerr
)
/
len
(
relerr
))
print
(
3
,
sum
(
relerr
)
/
len
(
relerr
))
def test_few_bit_quant():
    """Compare F.quantize with a brute-force nearest-code search.

    For every bit width from 2 to 8 a signed linear map is created and 100
    random 1x24 inputs, normalized to a maximum absolute value of 1, are
    quantized twice: once by picking the closest code entry per element in
    Python and once through ``F.quantize``. The chosen indices are expected
    to agree; where they differ, the dequantized result of ``F.quantize``
    must not have a larger mean absolute error than the brute-force result.
    """
    for bits in range(2, 9):
        code = F.create_linear_map(True, bits=bits).cuda()
        assert code.numel() == 256
        print(bits)

        for _ in range(100):
            values = torch.randn(1, 24, device='cuda')
            values /= values.abs().max()

            # Reference quantization: index and value of the closest code
            # entry for each element.
            q1 = []
            v1 = []
            for v in values[0]:
                idx = torch.abs(v - code).argmin()
                q1.append(idx.item())
                v1.append(code[idx].item())

            q1 = torch.Tensor(q1).cuda()
            v1 = torch.Tensor(v1).cuda()

            q2, S2 = F.quantize(values, code=code)
            v2 = F.dequantize(q2, S2)

            mismatch = q1.int() != q2.int()
            if mismatch.any():
                # The two searches may pick different indices (the map
                # contains repeated zero entries); that is acceptable only
                # if the resulting error is no worse than the reference.
                err1 = torch.abs(v1 - values).mean()
                err2 = torch.abs(v2 - values).mean()
                assert err2 <= err1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment