Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c05dd42d
Commit
c05dd42d
authored
Sep 13, 2022
by
Tim Dettmers
Browse files
Fixed cpu blockwise quantization for small input tensors.
parent
d8dbf3a9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
21 deletions
+22
-21
csrc/cpu_ops.cpp
csrc/cpu_ops.cpp
+7
-6
tests/test_functional.py
tests/test_functional.py
+15
-15
No files found.
csrc/cpu_ops.cpp
View file @
c05dd42d
...
@@ -30,11 +30,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
...
@@ -30,11 +30,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for
(
long
long
offset
=
0
;
offset
<
num_blocks
;
offset
+=
thread_wave_size
)
for
(
long
long
offset
=
0
;
offset
<
num_blocks
;
offset
+=
thread_wave_size
)
{
{
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
thread_wave_size
);
long
long
valid_chunks
=
num_blocks
-
offset
>=
thread_wave_size
?
thread_wave_size
:
num_blocks
-
offset
;
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
valid_chunks
);
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
thread_wave_size
*
sizeof
(
quantize_block_args
*
));
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
valid_chunks
*
sizeof
(
quantize_block_args
*
));
for
(
long
long
i
=
0
;
i
<
thread_wave_size
;
i
++
)
for
(
long
long
i
=
0
;
i
<
valid_chunks
;
i
++
)
args
[
i
]
=
(
quantize_block_args
*
)
malloc
(
sizeof
(
quantize_block_args
));
args
[
i
]
=
(
quantize_block_args
*
)
malloc
(
sizeof
(
quantize_block_args
));
int
chunks_processed
=
0
;
int
chunks_processed
=
0
;
...
@@ -56,14 +57,14 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
...
@@ -56,14 +57,14 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
pthread_create
(
&
threads
[
chunks_processed
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
pthread_create
(
&
threads
[
chunks_processed
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
chunks_processed
+=
1
;
chunks_processed
+=
1
;
if
(
chunks_processed
==
thread_wave_size
){
break
;
}
if
(
chunks_processed
==
valid_chunks
){
break
;
}
}
}
for
(
int
i
=
0
;
i
<
thread_wave_size
;
i
++
)
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
free
(
threads
);
free
(
threads
);
for
(
int
i
=
0
;
i
<
thread_wave_size
;
i
++
)
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
free
(
args
[
i
]);
free
(
args
[
i
]);
free
(
args
);
free
(
args
);
...
...
tests/test_functional.py
View file @
c05dd42d
...
@@ -2133,18 +2133,18 @@ def test_blockwise_cpu_large():
...
@@ -2133,18 +2133,18 @@ def test_blockwise_cpu_large():
reldiffs
=
[]
reldiffs
=
[]
batch
=
128
batch
=
128
seq
=
128
seq
=
128
hidden
=
14336
for
hidden
in
[
128
,
14336
]:
for
blocksize
in
[
4096
,
16384
]:
for
blocksize
in
[
4096
,
16384
]:
for
i
in
range
(
2
):
for
i
in
range
(
2
):
A1
=
torch
.
randn
(
batch
,
seq
,
hidden
,
device
=
'cpu'
)
A1
=
torch
.
randn
(
batch
,
seq
,
hidden
,
device
=
'cpu'
)
t0
=
time
.
time
()
t0
=
time
.
time
()
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
print
(
time
.
time
()
-
t0
)
print
(
time
.
time
()
-
t0
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
assert
diffs
[
-
1
]
<
0.011
assert
diffs
[
-
1
]
<
0.011
# print(sum(diffs)/len(diffs))
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
# print(sum(reldiffs)/len(reldiffs))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment