Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
19a7adca
Commit
19a7adca
authored
Sep 11, 2022
by
Tim Dettmers
Browse files
Fixed 2^31 max size issue for cpu blockwise quant.
parent
f0ae860c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
107 additions
and
128 deletions
+107
-128
bitsandbytes/functional.py
bitsandbytes/functional.py
+14
-76
csrc/common.cpp
csrc/common.cpp
+4
-4
csrc/common.h
csrc/common.h
+6
-4
csrc/cpu_ops.cpp
csrc/cpu_ops.cpp
+52
-37
csrc/cpu_ops.h
csrc/cpu_ops.h
+4
-3
csrc/pythonInterface.c
csrc/pythonInterface.c
+2
-2
tests/test_functional.py
tests/test_functional.py
+25
-2
No files found.
bitsandbytes/functional.py
View file @
19a7adca
...
...
@@ -369,13 +369,7 @@ def estimate_quantiles(
return
out
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
,
)
->
Tensor
:
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
,
blocksize
=
4096
)
->
Tensor
:
"""
Quantize tensor A in blocks of size 4096 values.
...
...
@@ -412,9 +406,9 @@ def quantize_blockwise(
if
absmax
is
None
:
n
=
A
.
numel
()
num_
blocks
=
4096
blocks
=
n
//
num_
blocks
blocks
+=
1
if
n
%
num_
blocks
>
0
else
0
blocks
ize
=
(
blocksize
if
A
.
device
.
type
==
'cpu'
else
4096
)
blocks
=
n
//
blocks
ize
blocks
+=
1
if
n
%
blocks
ize
>
0
else
0
absmax
=
torch
.
zeros
((
blocks
,),
device
=
A
.
device
)
if
out
is
None
:
...
...
@@ -426,46 +420,18 @@ def quantize_blockwise(
assert
rand
.
numel
()
>=
1024
rand_offset
=
random
.
randint
(
0
,
1023
)
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cquantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cquantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
...
...
@@ -473,13 +439,7 @@ def quantize_blockwise(
else
:
# cpu
assert
rand
is
None
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
return
out
,
(
absmax
,
code
)
...
...
@@ -529,43 +489,21 @@ def dequantize_blockwise(
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
)
if
blocksize
not
in
[
2048
,
4096
]:
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048 4096]"
)
if
A
.
device
.
type
!=
'cpu'
:
if
blocksize
not
in
[
2048
,
4096
]:
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048 4096]"
)
is_on_gpu
([
A
,
out
])
if
out
.
dtype
==
torch
.
float32
:
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
elif
out
.
dtype
==
torch
.
float16
:
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_longlong
(
blocksize
),
ct
.
c_longlong
(
A
.
numel
()))
return
out
...
...
csrc/common.cpp
View file @
19a7adca
...
...
@@ -12,16 +12,16 @@ void *quantize_block(void *arguments) {
// 1. find absmax in block
float
absmax_block
=
-
FLT_MAX
;
for
(
int
i
=
args
->
block_idx
;
i
<
args
->
block_end
;
i
++
)
for
(
long
long
i
=
args
->
block_idx
;
i
<
args
->
block_end
;
i
++
)
absmax_block
=
fmax
(
absmax_block
,
fabs
(
args
->
A
[
i
]));
args
->
absmax
[
args
->
block_idx
/
BLOCK_SIZE
]
=
absmax_block
;
args
->
absmax
[
args
->
block_idx
/
args
->
blocksize
]
=
absmax_block
;
for
(
int
i
=
args
->
block_idx
;
i
<
args
->
block_end
;
i
++
)
{
for
(
long
long
i
=
args
->
block_idx
;
i
<
args
->
block_end
;
i
++
)
{
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float
normed_value
=
args
->
A
[
i
]
/
absmax_block
;
int
idx
=
args
->
bin_searcher
->
scalar
(
normed_value
);
long
long
idx
=
args
->
bin_searcher
->
scalar
(
normed_value
);
// 4. check minimal distance
// The binary search always returns the value to the left, which might not be the closest value
...
...
csrc/common.h
View file @
19a7adca
...
...
@@ -5,18 +5,20 @@
using
namespace
BinSearch
;
#define BLOCK_SIZE 16384
struct
quantize_block_args
{
BinAlgo
<
Scalar
,
float
,
Direct2
>
*
bin_searcher
;
float
*
code
;
float
*
A
;
float
*
absmax
;
unsigned
char
*
out
;
int
block_end
;
int
block_idx
;
int
threadidx
;
long
long
block_end
;
long
long
block_idx
;
long
long
threadidx
;
long
long
blocksize
;
};
#define BLOCK_SIZE 4096
void
*
quantize_block
(
void
*
arguments
);
...
...
csrc/cpu_ops.cpp
View file @
19a7adca
...
...
@@ -4,54 +4,69 @@
using
namespace
BinSearch
;
void
dequantize_cpu
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
n
)
{
for
(
int
block_idx
=
0
;
block_idx
<
n
;
block_idx
+=
BLOCK_SIZE
)
{
int
valid_items
=
n
-
block_idx
>=
BLOCK_SIZE
?
BLOCK_SIZE
:
n
-
block_idx
;
int
block_end
=
block_idx
+
valid_items
;
for
(
int
i
=
block_idx
;
i
<
block_end
;
i
++
)
out
[
i
]
=
code
[
A
[
i
]]
*
absmax
[
block_idx
/
BLOCK_SIZE
];
void
dequantize_cpu
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
long
long
blocksize
,
long
long
n
)
{
for
(
long
long
block_idx
=
0
;
block_idx
<
n
;
block_idx
+=
blocksize
)
{
long
long
valid_items
=
n
-
block_idx
>=
blocksize
?
blocksize
:
n
-
block_idx
;
long
long
block_end
=
block_idx
+
valid_items
;
for
(
long
long
i
=
block_idx
;
i
<
block_end
;
i
++
)
out
[
i
]
=
code
[
A
[
i
]]
*
absmax
[
block_idx
/
blocksize
];
}
}
void
quantize_cpu
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
n
)
{
void
quantize_cpu
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
long
long
blocksize
,
long
long
n
)
{
// the default code has range [-0.993, 1.0], which can cause an error in the binary search algorithm used below
code
[
0
]
=
-
1.0
f
;
int
num_blocks
=
n
/
BLOCK_SIZE
;
num_blocks
+=
n
%
BLOCK_SIZE
==
0
?
0
:
1
;
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
num_blocks
);
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
num_blocks
*
sizeof
(
quantize_block_args
*
));
for
(
int
i
=
0
;
i
<
num_blocks
;
i
++
)
args
[
i
]
=
(
quantize_block_args
*
)
malloc
(
sizeof
(
quantize_block_args
));
long
long
num_blocks
=
n
/
blocksize
;
num_blocks
+=
n
%
blocksize
==
0
?
0
:
1
;
const
uint32
elements_code
=
256
;
BinAlgo
<
Scalar
,
float
,
Direct2
>
bin_searcher
(
code
,
elements_code
);
for
(
int
block_idx
=
0
;
block_idx
<
n
;
block_idx
+=
BLOCK_SIZE
)
{
int
valid_items
=
n
-
block_idx
>=
BLOCK_SIZE
?
BLOCK_SIZE
:
n
-
block_idx
;
int
block_end
=
block_idx
+
valid_items
;
struct
quantize_block_args
*
arg
=
args
[
block_idx
/
BLOCK_SIZE
];
arg
->
bin_searcher
=
&
bin_searcher
;
arg
->
code
=
code
;
arg
->
A
=
A
;
arg
->
absmax
=
absmax
;
arg
->
out
=
out
;
arg
->
block_end
=
block_end
;
arg
->
block_idx
=
block_idx
;
arg
->
threadidx
=
block_idx
/
BLOCK_SIZE
;
pthread_create
(
&
threads
[
block_idx
/
BLOCK_SIZE
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
}
int
thread_wave_size
=
256
;
// we chunk the threads into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for
(
long
long
offset
=
0
;
offset
<
num_blocks
;
offset
+=
thread_wave_size
)
{
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
thread_wave_size
);
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
thread_wave_size
*
sizeof
(
quantize_block_args
*
));
for
(
long
long
i
=
0
;
i
<
thread_wave_size
;
i
++
)
args
[
i
]
=
(
quantize_block_args
*
)
malloc
(
sizeof
(
quantize_block_args
));
for
(
int
i
=
0
;
i
<
num_blocks
;
i
++
)
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
int
chunks_processed
=
0
;
for
(
long
long
block_idx
=
offset
*
blocksize
;
block_idx
<
n
;
block_idx
+=
blocksize
)
{
long
long
valid_items
=
n
-
block_idx
>=
blocksize
?
blocksize
:
n
-
block_idx
;
long
long
block_end
=
block_idx
+
valid_items
;
struct
quantize_block_args
*
arg
=
args
[
chunks_processed
];
arg
->
bin_searcher
=
&
bin_searcher
;
arg
->
code
=
code
;
arg
->
A
=
A
;
arg
->
absmax
=
absmax
;
arg
->
out
=
out
;
arg
->
block_end
=
block_end
;
arg
->
block_idx
=
block_idx
;
arg
->
threadidx
=
block_idx
/
blocksize
;
arg
->
blocksize
=
blocksize
;
pthread_create
(
&
threads
[
chunks_processed
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
chunks_processed
+=
1
;
if
(
chunks_processed
==
thread_wave_size
){
break
;
}
}
for
(
int
i
=
0
;
i
<
thread_wave_size
;
i
++
)
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
free
(
threads
);
for
(
int
i
=
0
;
i
<
thread_wave_size
;
i
++
)
free
(
args
[
i
]);
free
(
args
);
}
free
(
threads
);
for
(
int
i
=
0
;
i
<
num_blocks
;
i
++
)
free
(
args
[
i
]);
free
(
args
);
}
csrc/cpu_ops.h
View file @
19a7adca
#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H
#include <iostream>
#include <stdio.h>
void
quantize_cpu
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
n
);
void
dequantize_cpu
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
n
);
void
quantize_cpu
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
long
long
blocksize
,
long
long
n
);
void
dequantize_cpu
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
long
long
blocksize
,
long
long
n
);
#endif
csrc/pythonInterface.c
View file @
19a7adca
...
...
@@ -287,7 +287,7 @@ extern "C"
void
cextractOutliers_ampere
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
){
extractOutliers_ampere
(
A
,
idx
,
out
,
idx_size
,
rows
,
cols
);
}
#endif
void
cquantize_blockwise_cpu_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
c
on
st
int
n
){
quantize_cpu
(
code
,
A
,
absmax
,
out
,
n
);
}
void
cdequantize_blockwise_cpu_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
c
on
st
int
n
){
dequantize_cpu
(
code
,
A
,
absmax
,
out
,
n
);
}
void
cquantize_blockwise_cpu_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
l
on
g
long
blocksize
,
long
long
n
){
quantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_cpu_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
l
on
g
long
blocksize
,
long
long
n
){
dequantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
}
tests/test_functional.py
View file @
19a7adca
...
...
@@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
batch_size
=
1
seqdim
=
1
values
=
[]
#
values.append((batch_size, seqdim, 768, 4 * 768))
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
values
.
append
((
batch_size
,
seqdim
,
12288
,
4
*
12288
))
#
values.append((batch_size, seqdim, 12288, 4*12288))
names
=
[
"batch_{0}_seq_{1}_model_{2}_hidden_{3}"
.
format
(
*
vals
)
for
vals
in
values
]
...
...
@@ -2125,3 +2125,26 @@ def test_extract_outliers():
assert
outliers2
.
shape
[
1
]
==
idx
.
numel
()
torch
.
testing
.
assert_allclose
(
outliers1
,
outliers2
)
def
test_blockwise_cpu_large
():
diffs
=
[]
reldiffs
=
[]
batch
=
128
seq
=
128
hidden
=
14336
for
blocksize
in
[
4096
,
16384
]:
for
i
in
range
(
2
):
A1
=
torch
.
randn
(
batch
,
seq
,
hidden
,
device
=
'cpu'
)
t0
=
time
.
time
()
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
print
(
time
.
time
()
-
t0
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
assert
diffs
[
-
1
]
<
0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment