Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
02fd80cb
Commit
02fd80cb
authored
Jul 04, 2023
by
Tim Dettmers
Browse files
Added bfloat16 quantizations and tests.
parent
dfe6900b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
172 additions
and
118 deletions
+172
-118
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+4
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+30
-16
csrc/kernels.cu
csrc/kernels.cu
+44
-21
csrc/ops.cu
csrc/ops.cu
+15
-18
csrc/pythonInterface.c
csrc/pythonInterface.c
+41
-15
tests/test_functional.py
tests/test_functional.py
+38
-47
No files found.
bitsandbytes/autograd/_functions.py
View file @
02fd80cb
...
...
@@ -561,4 +561,7 @@ def matmul(
def
matmul_4bit
(
A
:
tensor
,
B
:
tensor
,
quant_state
:
List
,
out
:
tensor
=
None
,
bias
=
None
):
assert
quant_state
is
not
None
return
MatMul4Bit
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
if
A
.
numel
()
==
A
.
shape
[
-
1
]
and
A
.
requires_grad
==
False
:
return
F
.
cutlass3_gemm
(
A
,
B
.
t
(),
out
,
state
=
quant_state
)
else
:
return
MatMul4Bit
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
bitsandbytes/functional.py
View file @
02fd80cb
...
...
@@ -617,6 +617,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
lib
.
cquantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
cblocksize
,
ct
.
c_int
(
A
.
numel
()))
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
cblocksize
,
ct
.
c_int
(
A
.
numel
()))
elif
A
.
dtype
==
torch
.
bfloat16
:
lib
.
cquantize_blockwise_bf16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
cblocksize
,
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
@@ -629,11 +631,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
offset
=
absmax
.
mean
()
absmax
-=
offset
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
blocksize
,
nested
=
False
)
state
=
[
qabsmax
,
code
,
blocksize
,
nested
,
offset
,
state2
]
state
=
[
qabsmax
,
code
,
blocksize
,
nested
,
A
.
dtype
,
offset
,
state2
]
else
:
state
=
[
absmax
,
code
,
blocksize
,
nested
,
None
,
None
]
state
=
[
absmax
,
code
,
blocksize
,
nested
,
A
.
dtype
,
None
,
None
]
return
out
,
state
...
...
@@ -678,18 +678,16 @@ def dequantize_blockwise(
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
float32
)
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
,
blocksize
)
assert
absmax
is
not
None
and
out
is
not
None
else
:
absmax
,
code
,
blocksize
,
nested
,
offset
,
state2
=
quant_state
if
nested
:
absmax
=
dequantize_blockwise
(
absmax
,
state2
)
absmax
+=
offset
quant_state
=
(
absmax
,
code
,
blocksize
,
False
,
torch
.
float32
,
None
,
None
)
absmax
,
code
,
blocksize
,
nested
,
dtype
,
offset
,
state2
=
quant_state
if
nested
:
absmax
=
dequantize_blockwise
(
absmax
,
state2
)
absmax
+=
offset
if
out
is
None
:
out
=
torch
.
empty
(
A
.
shape
,
dtype
=
dtype
,
device
=
A
.
device
)
if
A
.
device
.
type
!=
'cpu'
:
device
=
pre_call
(
A
.
device
)
...
...
@@ -701,6 +699,8 @@ def dequantize_blockwise(
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
elif
out
.
dtype
==
torch
.
float16
:
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
elif
out
.
dtype
==
torch
.
bfloat16
:
lib
.
cdequantize_blockwise_bf16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
@@ -774,6 +774,11 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
lib
.
cquantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cquantize_blockwise_fp16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
elif
A
.
dtype
==
torch
.
bfloat16
:
if
quant_type
==
'fp4'
:
lib
.
cquantize_blockwise_bf16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cquantize_blockwise_bf16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
@@ -860,6 +865,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
lib
.
cdequantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cdequantize_blockwise_fp16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
elif
out
.
dtype
==
torch
.
bfloat16
:
if
quant_type
==
'fp4'
:
lib
.
cdequantize_blockwise_bf16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cdequantize_blockwise_bf16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
@@ -1503,7 +1513,12 @@ def cutlass3_gemm(
ldc
=
ct
.
c_int32
(
ldc
)
if
B
.
dtype
==
torch
.
uint8
:
lib
.
cgemm_4bit_inference_naive
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
state
[
0
]),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
[
3
]))
if
A
.
dtype
==
torch
.
float16
:
lib
.
cgemm_4bit_inference_naive_fp16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
state
[
0
]),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
[
3
]))
elif
A
.
dtype
==
torch
.
bfloat16
:
lib
.
cgemm_4bit_inference_naive_bf16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
state
[
0
]),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
[
3
]))
else
:
raise
NotImplementedError
(
f
'Matmul not implemented for data type
{
A
.
dtype
}
'
)
elif
A
.
dtype
==
torch
.
float32
:
lib
.
cgemm_host_fp32
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
)
elif
A
.
dtype
==
torch
.
float16
:
...
...
@@ -1515,7 +1530,6 @@ def cutlass3_gemm(
def
igemm
(
A
:
Tensor
,
B
:
Tensor
,
...
...
csrc/kernels.cu
View file @
02fd80cb
...
...
@@ -3540,7 +3540,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
unsigned
char
local_B_4bit
[
num_values_8bit
];
T
local_B
[
num_values_4bit
];
T
local_A
[
num_values_4bit
];
__shared__
half
quant_map
[
16
*
THREADS
];
__shared__
T
quant_map
[
16
*
THREADS
];
for
(
int
i
=
0
;
i
<
16
;
i
++
)
quant_map
[
threadIdx
.
x
+
(
i
*
blockDim
.
x
)]
=
nf4_data
[
i
];
...
...
@@ -3769,11 +3769,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
template
__global__
void
kgemm_4bit_inference
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference
<
half
,
256
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
32
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
96
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
128
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
256
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
__nv_bfloat16
,
128
>(
int
M
,
int
N
,
int
K
,
__nv_bfloat16
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
...
...
@@ -3929,6 +3926,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
1
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
General8bit
)
...
...
@@ -3937,13 +3948,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise
(
float
,
256
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
1024
,
4
,
0
,
FP4
)
...
...
@@ -3951,13 +3955,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise
(
float
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
1024
,
4
,
0
,
NF4
)
...
...
@@ -3966,12 +3963,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
4096
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
4096
,
4
,
1
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
2048
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
1024
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
512
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
256
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
128
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
64
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
1024
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
512
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
1024
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
512
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
256
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
__nv_bfloat16
,
64
,
2
,
0
,
NF4
)
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
__nv_bfloat16
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
__nv_bfloat16
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
__nv_bfloat16
,
512
,
64
,
8
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
const
int
blocksize
,
const
int
n
);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
...
...
csrc/ops.cu
View file @
02fd80cb
...
...
@@ -733,20 +733,8 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
{
int
num_blocks
=
(
m
+
3
)
/
4
;
//int num_blocks = m;
cout
<<
num_blocks
<<
endl
;
//cout << lda << endl;
//cout << ldb << endl;
//cout << ldc << endl;
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
kgemm_4bit_inference_naive
<
T
,
128
><<<
num_blocks
,
128
,
0
,
0
>>>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template
<
typename
T
,
int
FUNC
>
void
func
(
T
*
A
,
T
*
B
,
T
value
,
long
n
)
...
...
@@ -770,6 +758,7 @@ template void func<float, _MUL>(float *A, float *B, float value, long n);
template
void
gemm_4bit_inference
<
half
>(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
void
gemm_4bit_inference_naive
<
half
>(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
void
gemm_4bit_inference_naive
<
__nv_bfloat16
>(
int
m
,
int
n
,
int
k
,
__nv_bfloat16
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template
void
gemm_host
<
half
>(
int
m
,
int
n
,
int
k
,
half
*
A
,
half
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
bits
);
template
void
extractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
);
...
...
@@ -796,19 +785,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n);
template
void
estimateQuantiles
(
float
*
A
,
float
*
code
,
float
offset
,
int
n
);
template
void
quantizeBlockwise
<
half
,
1
,
General8bit
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
1
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
General8bit
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
FP4
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
FP4
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
NF4
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
1
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
FP4
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
NF4
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
__nv_bfloat16
,
1
,
General8bit
>(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
__nv_bfloat16
,
0
,
General8bit
>(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
__nv_bfloat16
,
0
,
FP4
>(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
__nv_bfloat16
,
0
,
NF4
>(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
__nv_bfloat16
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
__nv_bfloat16
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
__nv_bfloat16
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
...
...
csrc/pythonInterface.c
View file @
02fd80cb
...
...
@@ -28,9 +28,12 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
void
gemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference
<
half
>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
void
gemm_4bit_inference_naive
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
void
gemm_4bit_inference_naive
_fp16
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference_naive
<
half
>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
void
gemm_4bit_inference_naive_bf16
(
int
m
,
int
n
,
int
k
,
__nv_bfloat16
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference_naive
<
__nv_bfloat16
>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
...
...
@@ -106,19 +109,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){
void
percentileClipping_g16
(
half
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping
<
half
>
(
g
,
gnorm_vec
,
step
,
n
);
}
void
quantizeBlockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp16_fp4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp16_nf4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_bf16
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
__nv_bfloat16
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_bf16_fp4
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
__nv_bfloat16
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_bf16_nf4
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
__nv_bfloat16
,
0
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_nf4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp32_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_bf16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
__nv_bfloat16
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_bf16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
__nv_bfloat16
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_bf16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
__nv_bfloat16
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
...
...
@@ -177,21 +190,31 @@ extern "C"
void
cestimate_quantiles_fp16
(
half
*
A
,
float
*
code
,
float
offset
,
int
n
){
estimateQuantiles_fp16
(
A
,
code
,
offset
,
n
);
}
void
cquantize
(
float
*
code
,
float
*
A
,
unsigned
char
*
out
,
int
n
){
quantize
(
code
,
A
,
out
,
n
);
}
void
cdequantize
(
float
*
code
,
unsigned
char
*
A
,
float
*
out
,
int
n
){
dequantize
(
code
,
A
,
out
,
n
);
}
void
cquantize_blockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp
32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp
32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp
16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp
16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp16_fp4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp16_nf4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32_nf4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_bf16
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_bf16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_bf16_fp4
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_bf16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_bf16_nf4
(
float
*
code
,
__nv_bfloat16
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_bf16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_bf16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_bf16
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_bf16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_bf16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_bf16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_bf16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
...
...
@@ -348,9 +371,6 @@ extern "C"
void
cgemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
void
cgemm_4bit_inference_naive
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference_naive
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
void
*
cget_managed_ptr
(
size_t
bytes
)
{
void
*
ptr
;
...
...
@@ -374,6 +394,12 @@ extern "C"
CMAKE_ELEMENTWISE_FUNC
(
arange
,
fp32
,
float
,
ARANGE
)
CMAKE_ELEMENTWISE_FUNC
(
_mul
,
fp32
,
float
,
_MUL
)
void
cgemm_4bit_inference_naive_fp16
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference_naive_fp16
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
void
cgemm_4bit_inference_naive_bf16
(
int
m
,
int
n
,
int
k
,
__nv_bfloat16
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
__nv_bfloat16
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference_naive_bf16
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
#endif
void
cquantize_blockwise_cpu_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
long
long
blocksize
,
long
long
n
){
quantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_cpu_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
long
long
blocksize
,
long
long
n
){
dequantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
...
...
tests/test_functional.py
View file @
02fd80cb
...
...
@@ -154,34 +154,36 @@ def test_dynamic_quantization():
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
[
"fp32"
,
"fp16"
,
"bf16"
])
@
pytest
.
mark
.
parametrize
(
"nested"
,
[
False
,
True
],
ids
=
[
"False"
,
"True"
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
def
test_dynamic_blockwise_quantization
(
nested
,
blocksize
):
def
test_dynamic_blockwise_quantization
(
dtype
,
nested
,
blocksize
):
#print('')
diffs
=
[]
reldiffs
=
[]
for
i
in
range
(
100
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"cuda"
,
dtype
=
dtype
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
,
nested
=
nested
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diff
=
torch
.
abs
(
A1
-
A2
)
.
float
()
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert
abserr
<
0.011
assert
relerr
<
0.018
#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
assert
A2
.
dtype
==
dtype
diffs
=
[]
for
i
in
range
(
100
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
,
dtype
=
dtype
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
,
nested
=
nested
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
diff
=
torch
.
abs
(
A1
-
A2
)
.
float
()
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
...
...
@@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
assert
abserr
<
0.0035
assert
relerr
<
0.015
assert
A2
.
dtype
==
dtype
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
...
...
@@ -1773,8 +1776,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
print
(
"partial matmul"
,
time
.
time
()
-
t0
)
batch_size
=
32
seqdim
=
512
+
256
batch_size
=
1
seqdim
=
1
values
=
[]
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
...
...
@@ -1800,7 +1803,7 @@ def test_bench_matmul(batch, seq, model, hidden):
B_fp4
,
state
=
F
.
quantize_fp4
(
B
)
B_fp4_c
,
state_c
=
F
.
quantize_fp4
(
B
,
compress_statistics
=
True
)
B_nf4
,
state_nf4
=
F
.
quantize_nf4
(
B
)
B_nf4
,
state_nf4
=
F
.
quantize_nf4
(
B
)
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
False
).
cuda
().
half
()
linear8bit
.
eval
()
...
...
@@ -1813,6 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):
linear8bit_train
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit_train_thresh
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
threshold
=
6.0
).
cuda
().
half
()
F
.
cutlass3_gemm
(
A
,
B_nf4
.
t
(),
state
=
state_nf4
)
# warmup
for
i
in
range
(
iters
):
...
...
@@ -1844,7 +1848,8 @@ def test_bench_matmul(batch, seq, model, hidden):
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
bnb
.
matmul_4bit
(
A
,
B_nf4
.
t
(),
quant_state
=
state_nf4
)
#bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
F
.
cutlass3_gemm
(
A
,
B_nf4
.
t
(),
state
=
state_nf4
)
torch
.
cuda
.
synchronize
()
print
(
f
"bnb nf4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
...
...
@@ -2221,7 +2226,8 @@ def test_bench_dequantization():
def
test_fp4_quant
():
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
[
"fp32"
,
"fp16"
,
"bf16"
])
def
test_fp4_quant
(
dtype
):
vals
=
list
(
product
([
0
,
1
],
repeat
=
4
))
code
=
{}
...
...
@@ -2243,7 +2249,7 @@ def test_fp4_quant():
result
=
sign
*
exp
*
frac
code
[
idx
]
=
result
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
(
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
,
dtype
=
dtype
)
qa
,
SA
=
F
.
quantize_fp4
(
A1
,
blocksize
=
64
)
A2
=
F
.
dequantize_fp4
(
qa
,
SA
)
...
...
@@ -2252,7 +2258,7 @@ def test_fp4_quant():
idx
=
err
>
1.0
err
=
err
.
mean
()
assert
A2
.
dtype
==
dtype
assert
err
.
item
()
<
0.1
assert
relerr
.
item
()
<
0.28
...
...
@@ -2409,20 +2415,16 @@ def test_cutlass3_gemm(dtype):
print
(
dim
,
(
max_err
.
item
(),
max_relerr
.
item
()))
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
],
ids
=
[
'fp16'
,
'bf16'
])
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
def
test_gemm_4bit
(
dtype
):
print
(
''
)
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
for
dim
in
[
2
*
4096
]:
#for dim in [5120]:
#for dim in [6656]:
#for dim in [4]:
for
dim
in
[
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]:
errs
=
[]
relerrs
=
[]
max_err
=
0
max_relerr
=
0
for
i
in
range
(
100
):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
...
...
@@ -2443,14 +2445,13 @@ def test_gemm_4bit(dtype):
qB
,
state
=
F
.
quantize_nf4
(
B
)
F
.
dequantize_nf4
(
qB
,
state
)
#C
3
=
torch
.matmul(A, B.t())
#C
2
=
bnb
.matmul
_4bit
(A,
q
B.t()
, state
)
C2
=
F
.
cutlass3_gemm
(
A
,
qB
.
t
(),
state
=
state
)
C1
=
bnb
.
matmul
_4bit
(
A
,
q
B
.
t
()
,
state
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
#print(state)
#print(qB)
#print('')
#print(A)
#print(B)
...
...
@@ -2464,8 +2465,8 @@ def test_gemm_4bit(dtype):
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
# to test our implementation
err
=
torch
.
abs
(
C1
-
C2
)
mag
=
torch
.
abs
(
C1
)
+
1e-
8
err
=
torch
.
abs
(
C1
-
C2
)
.
float
()
mag
=
torch
.
abs
(
C1
)
.
float
()
+
1e-
5
relerr
=
err
/
mag
max_err
=
max
(
err
.
max
(),
max_err
)
max_relerr
=
max
(
relerr
.
max
(),
max_relerr
)
...
...
@@ -2476,27 +2477,17 @@ def test_gemm_4bit(dtype):
errs
.
append
(
err
)
relerrs
.
append
(
relerr
)
if
err
/
torch
.
abs
(
C1
).
mean
()
>
5e-5
or
err
>
3.2e-5
:
print
(
''
)
print
(
i
,
err
,
relerr
)
#print(A.flatten()[-6:])
#print(B.flatten()[-6:])
#out = A.flatten()[-6:]*B.flatten()[-6:]
#print(out)
#print(out[:-1].sum())
print
(
'='
*
80
)
#print(C1.flatten()[-6:])
#print(C2.flatten()[-6:])
#assert False, 'ERROR'
c
=
int
(
C1
.
numel
()
*
0.0014
*
(
dim
/
256
))
+
1
c
=
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
,
throw
=
False
)
print
(
c
/
math
.
sqrt
(
dim
))
print
(
''
)
print
(
dim
,
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
))
print
(
dim
,
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
))
print
(
dim
,
(
max_err
.
item
(),
max_relerr
.
item
()))
#print('')
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
#print(dim, (max_err.item(), max_relerr.item()))
#print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
#print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
0.011
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.15
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_managed
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment