Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
0f9d3020
Commit
0f9d3020
authored
Apr 19, 2023
by
Tim Dettmers
Browse files
Added nested quantization for blockwise quantization.
parent
7dc198fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
42 deletions
+55
-42
bitsandbytes/functional.py
bitsandbytes/functional.py
+18
-7
tests/test_functional.py
tests/test_functional.py
+37
-35
No files found.
bitsandbytes/functional.py
View file @
0f9d3020
...
@@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
...
@@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return
out
return
out
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
,
blocksize
=
4096
)
->
Tensor
:
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
,
blocksize
=
4096
,
nested
=
False
)
->
Tensor
:
"""
"""
Quantize tensor A in blocks of size 4096 values.
Quantize tensor A in blocks of size 4096 values.
...
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
...
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
if
A
.
device
.
type
!=
'cpu'
:
if
A
.
device
.
type
!=
'cpu'
:
assert
blocksize
in
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
32
]
assert
blocksize
in
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
]
cblocksize
=
ct
.
c_int32
(
blocksize
)
cblocksize
=
ct
.
c_int32
(
blocksize
)
prev_device
=
pre_call
(
A
.
device
)
prev_device
=
pre_call
(
A
.
device
)
code
=
code
.
to
(
A
.
device
)
code
=
code
.
to
(
A
.
device
)
...
@@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
...
@@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
assert
rand
is
None
assert
rand
is
None
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
state
=
[
absmax
,
code
,
blocksize
]
if
nested
:
offset
=
absmax
.
mean
()
absmax
-=
offset
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
blocksize
,
nested
=
False
)
state
=
[
qabsmax
,
code
,
blocksize
,
nested
,
offset
,
state2
]
else
:
state
=
[
absmax
,
code
,
blocksize
,
nested
,
None
,
None
]
return
out
,
state
return
out
,
state
...
@@ -628,6 +636,7 @@ def dequantize_blockwise(
...
@@ -628,6 +636,7 @@ def dequantize_blockwise(
code
:
Tensor
=
None
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
4096
,
blocksize
:
int
=
4096
,
nested
=
False
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Dequantizes blockwise quantized values.
Dequantizes blockwise quantized values.
...
@@ -665,13 +674,15 @@ def dequantize_blockwise(
...
@@ -665,13 +674,15 @@ def dequantize_blockwise(
if
quant_state
is
None
:
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
,
blocksize
)
quant_state
=
(
absmax
,
code
,
blocksize
)
else
:
else
:
absmax
,
code
,
blocksize
=
quant_state
absmax
,
code
,
blocksize
,
nested
,
offset
,
state2
=
quant_state
if
nested
:
absmax
=
dequantize_blockwise
(
absmax
,
state2
)
absmax
+=
offset
if
A
.
device
.
type
!=
'cpu'
:
if
A
.
device
.
type
!=
'cpu'
:
device
=
pre_call
(
A
.
device
)
device
=
pre_call
(
A
.
device
)
code
=
code
.
to
(
A
.
device
)
code
=
code
.
to
(
A
.
device
)
if
blocksize
not
in
[
2048
,
4096
,
1024
,
512
,
256
,
128
,
64
,
32
]:
if
blocksize
not
in
[
2048
,
4096
,
1024
,
512
,
256
,
128
,
64
]:
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
)
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
)
is_on_gpu
([
A
,
absmax
,
out
])
is_on_gpu
([
A
,
absmax
,
out
])
if
out
.
dtype
==
torch
.
float32
:
if
out
.
dtype
==
torch
.
float32
:
...
@@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
...
@@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
if
out
is
None
:
if
out
is
None
:
out
=
torch
.
zeros
(((
n
+
1
)
//
2
,
1
),
dtype
=
torch
.
uint8
,
device
=
A
.
device
)
out
=
torch
.
zeros
(((
n
+
1
)
//
2
,
1
),
dtype
=
torch
.
uint8
,
device
=
A
.
device
)
assert
blocksize
in
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
32
]
assert
blocksize
in
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
]
prev_device
=
pre_call
(
A
.
device
)
prev_device
=
pre_call
(
A
.
device
)
is_on_gpu
([
A
,
out
,
absmax
])
is_on_gpu
([
A
,
out
,
absmax
])
...
...
tests/test_functional.py
View file @
0f9d3020
...
@@ -150,42 +150,44 @@ def test_dynamic_quantization():
...
@@ -150,42 +150,44 @@ def test_dynamic_quantization():
assert
diff
<
0.004
assert
diff
<
0.004
def
test_dynamic_blockwise_quantization
():
@
pytest
.
mark
.
parametrize
(
"nested"
,
[
False
,
True
],
ids
=
[
"False"
,
"True"
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
def
test_dynamic_blockwise_quantization
(
nested
,
blocksize
):
#print('')
#print('')
for
blocksize
in
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
32
]:
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
,
nested
=
nested
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
assert
abserr
<
0.011
assert
abserr
<
0.011
assert
relerr
<
0.018
assert
relerr
<
0.018
print
(
'nested='
,
nested
,
'randn'
,
blocksize
,
sum
(
diffs
)
/
len
(
diffs
))
#print('randn', blocksize, sum(diffs)/len(diffs))
print
(
'nested='
,
nested
,
'randn'
,
blocksize
,
sum
(
reldiffs
)
/
len
(
reldiffs
))
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
diffs
=
[]
diffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
,
nested
=
nested
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
assert
abserr
<
0.0035
assert
abserr
<
0.0035
assert
relerr
<
0.015
assert
relerr
<
0.015
print
(
'nested='
,
nested
,
'rand'
,
blocksize
,
sum
(
diffs
)
/
len
(
diffs
))
#print('rand', blocksize, sum(diffs)/len(diffs))
print
(
'nested='
,
nested
,
'rand'
,
blocksize
,
sum
(
reldiffs
)
/
len
(
reldiffs
))
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
def
test_dynamic_blockwise_stochastic_quantization
():
def
test_dynamic_blockwise_stochastic_quantization
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment