Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
e0e697b1
Commit
e0e697b1
authored
Nov 06, 2022
by
Tim Dettmers
Browse files
Fixed blockwise test and logic.
parent
6bc2b992
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
11 deletions
+9
-11
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-6
tests/test_functional.py
tests/test_functional.py
+5
-5
No files found.
bitsandbytes/functional.py
View file @
e0e697b1
...
@@ -466,7 +466,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
...
@@ -466,7 +466,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
if
absmax
is
None
:
if
absmax
is
None
:
n
=
A
.
numel
()
n
=
A
.
numel
()
blocksize
=
(
blocksize
if
A
.
device
.
type
==
'c
p
u'
else
4096
)
blocksize
=
(
blocksize
if
A
.
device
.
type
==
'cu
da
'
else
4096
)
blocks
=
n
//
blocksize
blocks
=
n
//
blocksize
blocks
+=
1
if
n
%
blocksize
>
0
else
0
blocks
+=
1
if
n
%
blocksize
>
0
else
0
absmax
=
torch
.
zeros
((
blocks
,),
device
=
A
.
device
)
absmax
=
torch
.
zeros
((
blocks
,),
device
=
A
.
device
)
...
@@ -550,17 +550,15 @@ def dequantize_blockwise(
...
@@ -550,17 +550,15 @@ def dequantize_blockwise(
if
A
.
device
.
type
!=
'cpu'
:
if
A
.
device
.
type
!=
'cpu'
:
if
blocksize
not
in
[
2048
,
4096
]:
if
blocksize
not
in
[
2048
,
4096
,
1024
,
512
]:
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048 4096]"
)
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048
,
4096
, 1024, 512
]"
)
is_on_gpu
([
A
,
out
])
is_on_gpu
([
A
,
out
])
if
out
.
dtype
==
torch
.
float32
:
if
out
.
dtype
==
torch
.
float32
:
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
elif
out
.
dtype
==
torch
.
float16
:
elif
out
.
dtype
==
torch
.
float16
:
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
else
:
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
...
...
tests/test_functional.py
View file @
e0e697b1
...
@@ -157,8 +157,8 @@ def test_dynamic_blockwise_quantization():
...
@@ -157,8 +157,8 @@ def test_dynamic_blockwise_quantization():
reldiffs
=
[]
reldiffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
C
,
S
=
F
.
quantize_blockwise
(
A1
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
...
@@ -173,13 +173,13 @@ def test_dynamic_blockwise_quantization():
...
@@ -173,13 +173,13 @@ def test_dynamic_blockwise_quantization():
diffs
=
[]
diffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
C
,
S
=
F
.
quantize_blockwise
(
A1
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
torch
.
testing
.
assert_allclose
(
A1
,
A2
,
atol
=
1e-2
,
rtol
=
0
)
#
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
assert
abserr
<
0.0035
assert
abserr
<
0.0035
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment