OpenDAS / bitsandbytes · Commits

Commit 3ac5840c authored Feb 04, 2023 by Tim Dettmers

Added fp4 quant/dequant and dequant optimizations.

parent 0f5c3948
Showing 9 changed files with 471 additions and 116 deletions.
bitsandbytes/cextension.py (+1, -1)
bitsandbytes/cuda_setup/main.py (+4, -0)
bitsandbytes/functional.py (+118, -5)
csrc/kernels.cu (+220, -68)
csrc/kernels.cuh (+2, -2)
csrc/ops.cu (+25, -29)
csrc/ops.cuh (+2, -2)
csrc/pythonInterface.c (+16, -7)
tests/test_functional.py (+83, -2)
bitsandbytes/cextension.py

@@ -9,7 +9,7 @@ from bitsandbytes.cuda_setup.main import CUDASetup

 setup = CUDASetup.get_instance()
-if setup.initialized != True:
+if not setup.initialized:
     setup.run_cuda_setup()

 if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
     setup.print_log_stack()
bitsandbytes/cuda_setup/main.py

@@ -35,6 +35,9 @@ class CUDASetup:
             raise RuntimeError("Call get_instance() instead")

     def generate_instructions(self):
+        if getattr(self, 'error', False): return
+        print(self.error)
+        self.error = True
         if self.cuda is None:
             self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
             self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')

@@ -84,6 +87,7 @@ class CUDASetup:
         self.has_printed = False
         self.lib = None
         self.initialized = False
+        self.error = False

     def run_cuda_setup(self):
         self.initialized = True
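The guard added to generate_instructions short-circuits repeated calls: once self.error is truthy, the method returns before regenerating instructions. A minimal sketch of the same print-once guard in isolation (illustrative only; the class name is hypothetical, not part of the commit):

    class ReportOnce:
        def report(self, message):
            # Skip if a previous call already reported.
            if getattr(self, 'error', False): return
            print(message)
            self.error = True  # remember that we reported

    r = ReportOnce()
    r.report('CUDA SETUP: Problem: ...')  # prints
    r.report('CUDA SETUP: Problem: ...')  # silent on the second call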
bitsandbytes/functional.py

@@ -168,7 +168,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)-1
+    bias = 2**(exponent_bits-1)+1
+    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)

@@ -176,10 +177,12 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias-1)
+                value = value*2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-2)
+                print(value, 1)
+                value = value*2**-(evalue-bias-1)
+                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)

@@ -193,7 +196,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    code /= code.max()
+    # code /= code.max()
     return code

@@ -587,7 +590,7 @@ def dequantize_blockwise(
     code = code.to(A.device)
     if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
         raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
-    is_on_gpu([A, out])
+    is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
         lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
     elif out.dtype == torch.float16:

@@ -602,6 +605,116 @@ def dequantize_blockwise(
     return out

+
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor:
+    """
+    Quantize tensor A in blocks of FP4 values.
+
+    Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input tensor.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        The output tensor (8-bit).
+    blocksize : int
+        The blocksize used in quantization.
+
+    Returns
+    -------
+    torch.Tensor:
+        The 8-bit tensor with packed 4-bit values.
+    tuple(torch.Tensor, torch.Size, torch.dtype):
+        The quantization state to undo the quantization.
+    """
+    if A.device.type != 'cuda':
+        raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
+
+    n = A.numel()
+    input_shape = A.shape
+
+    if absmax is None:
+        blocks = n // blocksize
+        blocks += 1 if n % blocksize > 0 else 0
+        absmax = torch.zeros((blocks,), device=A.device)
+
+    state = (absmax, input_shape, A.dtype)
+
+    if out is None:
+        out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device)
+
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+
+    prev_device = pre_call(A.device)
+    is_on_gpu([A, out, absmax])
+    if A.dtype == torch.float32:
+        lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+    elif A.dtype == torch.float16:
+        lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+    else:
+        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+    post_call(A.device)
+
+    return out, state
+
+
+def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+    """
+    Dequantizes FP4 blockwise quantized values.
+
+    Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input 8-bit tensor (packed 4-bit values).
+    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
+        Tuple of absmax values, original tensor shape and original dtype.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        Dequantized output tensor.
+
+    Returns
+    -------
+    torch.Tensor:
+        Dequantized tensor.
+    """
+    if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+
+    if quant_state is None:
+        assert absmax is not None and out is not None
+        shape = out.shape
+        dtype = out.dtype
+    else:
+        absmax, shape, dtype = quant_state
+
+    if out is None:
+        out = torch.empty(shape, dtype=dtype, device=A.device)
+
+    n = out.numel()
+
+    device = pre_call(A.device)
+    is_on_gpu([A, absmax, out])
+    if out.dtype == torch.float32:
+        lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+    elif out.dtype == torch.float16:
+        lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+    else:
+        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+    post_call(A.device)
+
+    return out
+
+
 def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
     if code is None:
         if "dynamic" not in name2qmap:
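Taken together, quantize_fp4 and dequantize_fp4 give a packed-FP4 round trip. A minimal usage sketch of the API added above (not part of the commit; assumes a CUDA device), mirroring the pattern used in tests/test_functional.py later in this diff:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, device='cuda').half()

    # Quantize: returns a uint8 tensor with two 4-bit values per byte,
    # plus the state (per-block absmax, original shape, original dtype).
    qa, state = F.quantize_fp4(A, blocksize=64)

    # Dequantize back to the original shape and dtype using the state.
    A2 = F.dequantize_fp4(qa, state)
    print((A - A2).abs().float().mean())  # small, nonzero quantization error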
csrc/kernels.cu

This diff is collapsed.
csrc/kernels.cuh

@@ -14,8 +14,8 @@ template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A,
 __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
 __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);

-template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, const int n);
+template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int FP4> __global__ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, const int blocksize, const int n);

 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
csrc/ops.cu

@@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
+template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
 {
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

@@ -58,42 +58,34 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
     assert(blocksize == 4096);

   if(blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
-    kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 2048, 4, 0, FP4><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
-    kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 1024, 4, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 512)
-    kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 512, 2, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 256)
-    kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 256, 2, 0, FP4><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 128)
-    kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
-    kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);

   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

-template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
+template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
 {
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
-  if(blocksize == 4096)
-    kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 2048)
-    kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 1024)
-    kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 512)
-    kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 256)
-    kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 128)
-    kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 64)
-    kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
+  int tile_size = FP4 ? 1024 : 512;
+
+  if(FP4)
+    kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
+  else
+    kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);

   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

@@ -688,12 +680,16 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
 template void estimateQuantiles(half *A, float *code, float offset, int n);
 template void estimateQuantiles(float *A, float *code, float offset, int n);

-template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
-template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
-template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
+template void quantizeBlockwise<half, 0, 0>(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 0, 0>(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<half, 0, 1>(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 0, 1>(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<half, 1, 0>(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 1, 0>(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template void dequantizeBlockwise<half, 0>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
+template void dequantizeBlockwise<float, 0>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
+template void dequantizeBlockwise<half, 1>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
+template void dequantizeBlockwise<float, 1>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);

 #define MAKE_optimizer32bit(name, gtype) \
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
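The rewritten dequantizeBlockwise is the "dequant optimization" from the commit message: the per-blocksize template specializations collapse into one launch where 64 threads each handle 8 bytes, so a thread block covers a fixed 512-byte tile, and the grid size is a ceiling division of n by the tile size (1024 elements for FP4, since each byte packs two 4-bit values). A quick sketch of that grid arithmetic (illustrative only, not part of the commit):

    def dequant_grid(n: int, fp4: bool) -> int:
        # FP4 input packs two nibbles per byte, so one tile covers
        # twice as many logical elements as the 8-bit code path.
        tile_size = 1024 if fp4 else 512
        return (n + tile_size - 1) // tile_size  # ceil(n / tile_size)

    assert dequant_grid(4096, fp4=False) == 8   # exact multiple of the tile
    assert dequant_grid(4097, fp4=False) == 9   # remainder gets its own block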
csrc/ops.cuh

@@ -128,8 +128,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
 void quantize(float *code, float *A, unsigned char *out, int n);
 void dequantize(float *code, unsigned char *A, float *out, int n);
-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
+template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n);
-template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
+template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);

 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
csrc/pythonInterface.c

@@ -75,13 +75,17 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
 void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }

-void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
-void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }

-void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
+void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 0>(code, A, absmax, out, blocksize, n); } \
-void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 0>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 1>(NULL, A, absmax, out, blocksize, n); } \
+void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 1>(NULL, A, absmax, out, blocksize, n); }

 #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
 void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \

@@ -148,6 +152,11 @@ extern "C"
   void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
   void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }

+  void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
+  void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
+  void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
+  void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }

   #define MAKE_CFUNC32(name, gtype, gbits) \
   void c##name##32bit_g##gbits(gtype *g, gtype *p, \
       float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
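The c-prefixed functions inside extern "C" are the symbols the Python side loads through ctypes. A hedged sketch of how the new FP4 entry points are reached, mirroring the call sites in the bitsandbytes/functional.py diff above (lib and get_ptr are the library's existing binding helpers):

    import ctypes as ct

    # From quantize_fp4, for a float32 tensor A with n elements:
    # the FP4 path ignores the code book, so a NULL pointer is passed
    # where the 8-bit path would pass get_ptr(code).
    lib.cquantize_blockwise_fp32_fp4(
        get_ptr(None),           # code: NULL (unused for FP4)
        get_ptr(A),              # input tensor
        get_ptr(absmax),         # per-block absolute maxima
        get_ptr(out),            # packed 4-bit output, dtype uint8
        ct.c_int32(blocksize),
        ct.c_int(n))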
tests/test_functional.py

@@ -152,7 +152,7 @@ def test_dynamic_quantization():

 def test_dynamic_blockwise_quantization():
     #print('')
-    for blocksize in [4096, 2048, 1024, 512]:
+    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]:
         diffs = []
         reldiffs = []
         for i in range(100):

@@ -2189,7 +2189,88 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
-        F.dequantize_blockwise(qa, SA, blocksize=2048)
+        #F.dequantize_blockwise(qa, SA, blocksize=2048)
+        qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)

+
+def test_fp4_quant():
+    vals = list(product([0, 1], repeat=4))
+
+    code = {}
+    for bits in vals:
+        result = 0
+        bias = 3
+        sign, e1, e2, p1 = bits
+        idx = sign*8 + e1*4 + e2*2 + p1*1
+        sign = -1.0 if sign else 1.0
+        exp = e1*2 + e2*1
+        if exp == 0:
+            # sub-normal
+            if p1 == 0: result = 0
+            else: result = sign*0.0625
+        else:
+            # normal
+            exp = 2**(-exp + bias + 1)
+            frac = 1.5 if p1 else 1.0
+            result = sign*exp*frac
+        code[idx] = result
+
+    A1 = torch.randn(1024, 1024, device='cuda').half()
+    qa, SA = F.quantize_fp4(A1, blocksize=64)
+    A2 = F.dequantize_fp4(qa, SA)
+
+    #qa, SA = F.quantize_fp4(A1, blocksize=128)
+    #A2 = F.dequantize_fp4(qa, SA, blocksize=128)
+
+    #A1 = A1.flatten().sort()[0]
+    #A2 = A2.flatten().sort()[0]
+
+    #print(A1)
+    #print(A2)
+    err = (A1 - A2).abs().float()
+    relerr = (err/A1.abs().float()).mean()
+    err = err.mean()
+
+    print(err, relerr)
+
+    #assert err.item() < 0.1
+    #assert relerr.item() < 0.28
+
+
+def test_bench_fp4_dequant():
+    blocksize = 256
+    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
+    qa, SA = F.quantize_fp4(a, blocksize=blocksize)
+
+    input_size = a.numel()/2
+    output_size = a.numel()*2
+    num_bytes = input_size+output_size
+    GB = num_bytes/1e9
+    max_theoretical_s = GB/768
+    print(max_theoretical_s*1e6)
+    b = torch.randn(128, 1024*12, device='cuda').half()
+
+    iters = 5
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        F.dequantize_fp4(qa, SA, blocksize=blocksize)
+        #b.copy_(a)
+    torch.cuda.synchronize()
+    print((time.time()-t0)/iters*1e6)
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        torch.matmul(b, a.t())
+    torch.cuda.synchronize()
+    print((time.time()-t0)/iters*1e6)
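For reference, the FP4 decode rule that test_fp4_quant builds its code dictionary from (1 sign bit, 2 exponent bits, 1 fraction bit, bias 3, exponent 0 treated as sub-normal) can be tabulated on the CPU. A standalone sketch, not part of the commit:

    from itertools import product

    def fp4_value(sign, e1, e2, p1, bias=3):
        exp = e1*2 + e2
        if exp == 0:
            # sub-normal: zero, or +/-0.0625 when the fraction bit is set
            return 0.0 if p1 == 0 else (-0.0625 if sign else 0.0625)
        frac = 1.5 if p1 else 1.0  # implicit leading one
        return (-1.0 if sign else 1.0) * 2.0**(-exp + bias + 1) * frac

    table = sorted(fp4_value(*bits) for bits in product([0, 1], repeat=4))
    print(table)
    # [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0.0, 0.0,
    #  0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0]
    # (unnormalized values as computed in the test; blockwise quantization
    #  scales inputs by the per-block absmax, so the ratios are what matter)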