Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c17fb8eb
Commit
c17fb8eb
authored
Mar 29, 2024
by
Matthew Douglas
Browse files
Fix 4bit quantization with blocksize=4096
parent
fd9d072e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
9 deletions
+28
-9
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-3
csrc/ops.cu
csrc/ops.cu
+1
-1
tests/test_functional.py
tests/test_functional.py
+23
-5
No files found.
bitsandbytes/functional.py
View file @
c17fb8eb
...
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
...
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
if
data
is
None
:
if
data
is
None
:
raise
NotImplementedError
(
f
"Typename
{
typename
}
not supported"
)
raise
NotImplementedError
(
f
"Typename
{
typename
}
not supported"
)
data
=
Tensor
(
data
)
data
=
torch
.
tensor
(
data
,
device
=
device
)
data
/=
data
.
abs
().
max
()
data
.
div_
(
data
.
abs
().
max
())
assert
data
.
numel
()
==
16
assert
data
.
numel
()
==
16
return
data
.
to
(
device
)
return
data
def
quantize_fp4
(
def
quantize_fp4
(
...
...
csrc/ops.cu
View file @
c17fb8eb
...
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
...
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
num_blocks
=
n
%
blocksize
==
0
?
num_blocks
:
num_blocks
+
1
;
num_blocks
=
n
%
blocksize
==
0
?
num_blocks
:
num_blocks
+
1
;
if
(
blocksize
==
4096
)
if
(
blocksize
==
4096
)
kQuantizeBlockwise
<
T
,
4096
,
4
,
STOCHASTIC
,
0
><<<
num_blocks
,
1024
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
4096
,
4
,
STOCHASTIC
,
DATA_TYPE
><<<
num_blocks
,
1024
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
2048
)
else
if
(
blocksize
==
2048
)
kQuantizeBlockwise
<
T
,
2048
,
4
,
0
,
DATA_TYPE
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
2048
,
4
,
0
,
DATA_TYPE
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
1024
)
else
if
(
blocksize
==
1024
)
...
...
tests/test_functional.py
View file @
c17fb8eb
...
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():
...
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
def
test_fp4_quant
(
dtype
):
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"fp4"
,
"nf4"
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
])
def
test_4bit_quant
(
dtype
,
quant_type
,
blocksize
):
vals
=
list
(
product
([
0
,
1
],
repeat
=
4
))
vals
=
list
(
product
([
0
,
1
],
repeat
=
4
))
code
=
{}
code
=
{}
...
@@ -1953,8 +1955,8 @@ def test_fp4_quant(dtype):
...
@@ -1953,8 +1955,8 @@ def test_fp4_quant(dtype):
code
[
idx
]
=
result
code
[
idx
]
=
result
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
,
dtype
=
dtype
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
,
dtype
=
dtype
)
qa
,
SA
=
F
.
quantize_
fp4
(
A1
,
blocksize
=
64
)
qa
,
SA
=
F
.
quantize_
4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_
fp4
(
qa
,
SA
)
A2
=
F
.
dequantize_
4bit
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-8
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-8
)).
mean
()
...
@@ -1962,8 +1964,24 @@ def test_fp4_quant(dtype):
...
@@ -1962,8 +1964,24 @@ def test_fp4_quant(dtype):
err
=
err
.
mean
()
err
=
err
.
mean
()
assert
A2
.
dtype
==
dtype
assert
A2
.
dtype
==
dtype
# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
if
blocksize
<=
64
:
assert
err
.
item
()
<
0.1
assert
err
.
item
()
<
0.1
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
elif
blocksize
<=
256
:
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.30
elif
blocksize
<=
512
:
assert
err
.
item
()
<
0.12
assert
relerr
.
item
()
<
0.31
elif
quant_type
==
"fp4"
:
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
assert
err
.
item
()
<
0.08
+
math
.
log2
(
blocksize
)
*
4e-2
else
:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert
err
.
item
()
<
math
.
log2
(
blocksize
)
*
8e-2
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"fp4"
,
"nf4"
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"fp4"
,
"nf4"
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment