Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
2489d819
Commit
2489d819
authored
Feb 14, 2023
by
Tim Dettmers
Browse files
Added more blocksizes for stochastic rounding; fixed dequant blocksize.
parent
2dfa3ce1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
37 additions
and
24 deletions
+37
-24
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+3
-3
bitsandbytes/functional.py
bitsandbytes/functional.py
+2
-3
csrc/kernels.cu
csrc/kernels.cu
+12
-0
csrc/ops.cu
csrc/ops.cu
+6
-8
csrc/pythonInterface.c
csrc/pythonInterface.c
+4
-4
tests/test_functional.py
tests/test_functional.py
+10
-6
No files found.
bitsandbytes/autograd/_functions.py
View file @
2489d819
...
...
@@ -413,10 +413,10 @@ class MatMulFP8(torch.autograd.Function):
# 2. MatmulnN
cA
,
state
=
F
.
quantize_blockwise
(
A
,
code
=
fw_code
,
blocksize
=
1024
)
fp8A
=
F
.
dequantize_blockwise
(
cA
,
state
).
to
(
A
.
dtype
)
fp8A
=
F
.
dequantize_blockwise
(
cA
,
state
,
blocksize
=
1024
).
to
(
A
.
dtype
)
cB
,
state
=
F
.
quantize_blockwise
(
B
,
code
=
fw_code
,
blocksize
=
1024
)
fp8B
=
F
.
dequantize_blockwise
(
cB
,
state
).
to
(
B
.
dtype
)
fp8B
=
F
.
dequantize_blockwise
(
cB
,
state
,
blocksize
=
1024
).
to
(
B
.
dtype
)
output
=
torch
.
matmul
(
fp8A
,
fp8B
)
...
...
@@ -443,7 +443,7 @@ class MatMulFP8(torch.autograd.Function):
grad_A
,
grad_B
=
None
,
None
cgrad_out
,
state
=
F
.
quantize_blockwise
(
grad_output
,
code
=
ctx
.
bw_code
,
blocksize
=
1024
)
fp8out
=
F
.
dequantize_blockwise
(
cgrad_out
,
state
).
to
(
grad_output
.
dtype
)
fp8out
=
F
.
dequantize_blockwise
(
cgrad_out
,
state
,
blocksize
=
1024
).
to
(
grad_output
.
dtype
)
# Cast grad_output to fp16
if
len
(
grad_output
.
shape
)
==
3
:
...
...
bitsandbytes/functional.py
View file @
2489d819
...
...
@@ -508,13 +508,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
code
=
code
.
to
(
A
.
device
)
if
rand
is
not
None
:
is_on_gpu
([
code
,
A
,
out
,
absmax
,
rand
])
assert
blocksize
==
4096
assert
rand
.
numel
()
>=
1024
rand_offset
=
random
.
randint
(
0
,
1023
)
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
cblocksize
,
ct
.
c_int
(
A
.
numel
()))
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
cblocksize
,
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
...
...
csrc/kernels.cu
View file @
2489d819
...
...
@@ -2797,16 +2797,28 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half
template
__global__
void
kQuantizeBlockwise
<
float
,
4096
,
4
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
2048
,
4
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
2048
,
4
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
2048
,
4
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
2048
,
4
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
1024
,
4
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
1024
,
4
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
1024
,
4
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
1024
,
4
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
512
,
2
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
512
,
2
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
512
,
2
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
512
,
2
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
256
,
2
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
256
,
2
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
256
,
2
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
256
,
2
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
128
,
2
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
128
,
2
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
128
,
2
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
128
,
2
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
64
,
1
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
64
,
1
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
64
,
1
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
64
,
1
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
4096
,
1024
,
4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
4096
,
1024
,
4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
n
);
...
...
csrc/ops.cu
View file @
2489d819
...
...
@@ -54,23 +54,21 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
{
int
num_blocks
=
n
/
blocksize
;
num_blocks
=
n
%
blocksize
==
0
?
num_blocks
:
num_blocks
+
1
;
if
(
STOCHASTIC
==
1
)
assert
(
blocksize
==
4096
);
if
(
blocksize
==
4096
)
kQuantizeBlockwise
<
T
,
4096
,
4
,
STOCHASTIC
><<<
num_blocks
,
1024
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
2048
)
kQuantizeBlockwise
<
T
,
2048
,
4
,
0
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
2048
,
4
,
STOCHASTIC
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
1024
)
kQuantizeBlockwise
<
T
,
1024
,
4
,
0
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
1024
,
4
,
STOCHASTIC
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
512
)
kQuantizeBlockwise
<
T
,
512
,
2
,
0
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
512
,
2
,
STOCHASTIC
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
256
)
kQuantizeBlockwise
<
T
,
256
,
2
,
0
><<<
num_blocks
,
128
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
256
,
2
,
STOCHASTIC
><<<
num_blocks
,
128
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
128
)
kQuantizeBlockwise
<
T
,
128
,
2
,
0
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
128
,
2
,
STOCHASTIC
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
64
)
kQuantizeBlockwise
<
T
,
64
,
1
,
0
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
64
,
1
,
STOCHASTIC
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
...
...
csrc/pythonInterface.c
View file @
2489d819
...
...
@@ -77,8 +77,8 @@ void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){
void
quantizeBlockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
half
,
1
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
float
,
1
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
1
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
blocksize
,
n
);
}
void
quantizeBlockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
1
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
...
...
@@ -142,8 +142,8 @@ extern "C"
void
cdequantize
(
float
*
code
,
unsigned
char
*
A
,
float
*
out
,
int
n
){
dequantize
(
code
,
A
,
out
,
n
);
}
void
cquantize_blockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise_stochastic_fp16
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
}
void
cquantize_blockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise_stochastic_fp32
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
}
void
cquantize_blockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_stochastic_fp16
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
blocksize
,
n
);
}
void
cquantize_blockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_stochastic_fp32
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
...
...
tests/test_functional.py
View file @
2489d819
...
...
@@ -188,21 +188,25 @@ def test_dynamic_blockwise_quantization():
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
def
test_dynamic_blockwise_stochastic_quantization
():
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
def
test_dynamic_blockwise_stochastic_quantization
(
blocksize
):
diffs
=
[]
reldiffs
=
[]
rand
=
torch
.
rand
(
1024
).
cuda
()
err
=
0
for
i
in
range
(
100
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
C1
,
S1
=
F
.
quantize_blockwise
(
A1
,
rand
=
rand
)
C2
,
S2
=
F
.
quantize_blockwise
(
A1
)
C1
,
S1
=
F
.
quantize_blockwise
(
A1
,
rand
=
rand
,
blocksize
=
blocksize
)
C2
,
S2
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C1
,
S1
,
blocksize
=
blocksize
)
err
+=
(
A1
-
A2
).
abs
().
mean
().
item
()
/
100
# a maximunm distance of quantized values of 1
torch
.
testing
.
assert_allclose
(
C1
,
C2
,
atol
=
1
,
rtol
=
0
)
fraction_smaller
=
(
C1
<
C2
).
float
().
sum
()
/
C1
.
numel
()
fraction_larger
=
(
C1
>
C2
).
float
().
sum
()
/
C1
.
numel
()
torch
.
testing
.
assert_allclose
(
fraction_larger
,
fraction_smaller
,
atol
=
0.01
,
rtol
=
0
)
torch
.
testing
.
assert_allclose
(
fraction_larger
,
fraction_smaller
,
atol
=
0.01
,
rtol
=
0
)
assert
err
<
0.019
@
pytest
.
mark
.
parametrize
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment