OpenDAS / bitsandbytes / Commits

Commit 4b88d69d
Authored Jul 09, 2023 by Tim Dettmers

Added arbitrary data types; fixed a bug for small matrices.

Parent: eefbf602
Changes: showing 8 changed files with 98 additions and 109 deletions.
bitsandbytes/autograd/_functions.py    +1   -1
bitsandbytes/functional.py             +55  -15
csrc/kernels.cu                        +9   -4
csrc/kernels.cuh                       +1   -1
csrc/ops.cu                            +4   -4
csrc/ops.cuh                           +1   -1
csrc/pythonInterface.c                 +8   -8
tests/test_functional.py               +19  -75
bitsandbytes/autograd/_functions.py

@@ -562,6 +562,6 @@ def matmul(

 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
+        return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
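Note on the change above: matmul_4bit routes only single-row, no-grad inputs to the fused 4-bit GEMV; everything else still goes through the MatMul4Bit autograd function. A minimal sketch of which branch a call takes, assuming a CUDA build of bitsandbytes and the interfaces shown in this commit (quantize_nf4, matmul_4bit):

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qB, state = F.quantize_nf4(B)

    # One token at inference time: A.numel() == A.shape[-1] and requires_grad is False,
    # so matmul_4bit takes the fused F.gemv_4bit path introduced here.
    A_infer = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
    out_fast = bnb.matmul_4bit(A_infer, qB.t(), quant_state=state)

    # A multi-row batch (or any input that requires grad) falls back to MatMul4Bit.apply.
    A_train = torch.randn(8, 4096, dtype=torch.float16, device='cuda', requires_grad=True)
    out_autograd = bnb.matmul_4bit(A_train, qB.t(), quant_state=state)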
bitsandbytes/functional.py
...
...
@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):

         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
         v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
         v = v1 + v2 + v3
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
         v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
         v = v1 + v2 + v3

     values = torch.Tensor(v)
     values = values.sort().values
     values /= values.max()

     assert values.numel() == 256

     return values


 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
...
...
@@ -710,6 +712,47 @@ def dequantize_blockwise(

     return out


 def get_4bit_type(typename, device=None, blocksize=64):
     if device is None: device = 'cuda'
     data = None
     if typename == 'nf4':
         data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
                 -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
                 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
                 0.7229568362236023, 1.0]
     elif typename == 'fp4':
         # 0b000 = 0
         # 0b001 = 0.0625
         # 0b010 = 8
         # 0b011 = 12
         # 0b100 = 4
         # 0b101 = 6
         # 0b110 = 2
         # 0b111 = 3
         data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
     elif typename == 'int4':
         data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
     elif typename == 'af4':
         # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
         # https://arxiv.org/abs/2306.06965
         if blocksize == 64:
             data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
                     -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
                     0.42563882, 0.55496234, 0.72424863, 1.][::-1]
         else:
             raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.')

     if data is None:
         raise NotImplementedError(f'Typename {typename} not supported')

     data = Tensor(data)
     data /= data.abs().max()
     assert data.numel() == 16

     return data.to(device)


 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
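The new get_4bit_type helper materialises the 16 code values of the chosen 4-bit format as a tensor rescaled so that the largest magnitude is 1.0. A short usage sketch, assuming a CUDA device is available:

    import bitsandbytes.functional as F

    nf4 = F.get_4bit_type('nf4', device='cuda')                 # normal-float (NF4) grid
    fp4 = F.get_4bit_type('fp4', device='cuda')                 # 4-bit float values, normalised by their max of 12
    af4 = F.get_4bit_type('af4', device='cuda', blocksize=64)   # AbnormalFloat, blocksize 64 only

    assert nf4.numel() == fp4.numel() == af4.numel() == 16
    assert float(nf4.abs().max()) == 1.0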
...
...
@@ -783,6 +826,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz

         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

+    datatype = get_4bit_type(quant_type, device=A.device)
+
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
...
...
@@ -790,9 +835,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz

         #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]

     return out, state
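With this change the state returned by quantize_4bit carries the code table itself as its last element, so downstream kernels no longer have to assume NF4. A sketch of inspecting that state; the list positions are read off the assignments above and the quant_type keyword is inferred from the wrappers, so treat both as assumptions:

    import torch
    import bitsandbytes.functional as F

    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qB, state = F.quantize_4bit(B, quant_type='nf4')

    # Uncompressed-statistics layout as assembled above:
    absmax, input_shape, dtype, blocksize, compressed_stats, quant_type, datatype = state
    assert datatype.numel() == 16     # the 4-bit code table added by this commit
    assert state[-1] is datatype      # gemv_4bit hands state[-1] to the CUDA kernel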
...
...
@@ -839,7 +884,7 @@ def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax:

         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state

     if compressed_stats is not None:
...
...
@@ -1408,13 +1453,14 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8

     return sout

-def cutlass3_gemm(
+def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    state=None,
+    storage_type='nf4'
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
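cutlass3_gemm is renamed to gemv_4bit, which better describes what it does: a quantized matrix-vector product for the single-row inference case, driven by the quant_state. A hedged usage sketch mirroring the updated benchmark later in this commit:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')          # one row
    B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')
    qB, state_nf4 = F.quantize_nf4(B)

    out = F.gemv_4bit(A, qB.t(), state=state_nf4)   # compare against torch.matmul(A, B.t())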
...
...
@@ -1491,8 +1537,6 @@ def cutlass3_gemm(

         ldb = sA[2]
         ldc = m

-        ptr = CUBLAS_Context.get_instance().get_context(A.device)
-
         # B^T @ A^T = C^T
         # [km, nk -> mn]
         #lda = ldb = ldc = 1
...
...
@@ -1514,15 +1558,11 @@ def cutlass3_gemm(

     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     elif A.dtype == torch.float32:
         lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     elif A.dtype == torch.float16:
         lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
...
...
csrc/kernels.cu
...
...
@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i

 }

 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T *out, int lda, int ldb, int ldc, int blocksize)
 {
   // per threadblock:
...
...
@@ -3568,7 +3568,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in

       {
         #pragma unroll
         for(int j = 0; j < (num_values_8bit); j++)
-          if((inner_idx/2) + j < K)
+          if((inner_idx_halved) + j < K)
             local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
+          else
+            local_B_4bit[j] = 0b01110111;
       }
     }
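The new else branch pads the per-thread byte buffer whenever fewer than num_values_8bit packed bytes remain in the row, which is presumably the "bug for small matrices" from the commit message. The padding byte 0b01110111 holds index 7 in both nibbles, and index 7 of the NF4 table is exactly 0.0, so padded entries dequantize to zero and do not perturb the dot product. A small check of that property in Python, using the nf4 table from get_4bit_type above:

    nf4 = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
           -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
           0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
           0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
           0.7229568362236023, 1.0]

    pad = 0b01110111
    assert nf4[pad >> 4] == 0.0 and nf4[pad & 0x0F] == 0.0   # both nibbles decode to 0.0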
...
...
@@ -3578,6 +3580,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in

       {
         local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
         local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
         //if(threadIdx.x == 0)
         //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
       }

       if(inner_idx+num_values_4bit)
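Each packed byte decodes to two values: the high nibble and the low nibble each index quant_map (presumably filled from the new datatype argument) and are scaled by the block's absmax. The same step as a Python reference sketch, with quant_map and absmax as placeholders:

    import torch

    def dequantize_byte(b: int, quant_map: torch.Tensor, absmax: float):
        hi = quant_map[b >> 4] * absmax     # local_B[k*2]
        lo = quant_map[b & 0x0F] * absmax   # local_B[k*2 + 1]
        return hi, lo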
...
...
@@ -3773,8 +3778,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha

 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);

-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half *out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize);

 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
...
...
csrc/kernels.cuh
...
...
@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *

 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T *B, T *out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T *out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
...
...
csrc/ops.cu
...
...
@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi

   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T *A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize)
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T *A, unsigned char *B, float *absmax, float *datatype, T *out, int lda, int ldb, int ldc, int blocksize)
 {
   int num_blocks = (m+3)/4;

-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
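On the host side the naive GEMV launches (m + 3) / 4 blocks of 128 threads, i.e. one block per four output rows through integer ceiling division; the four-rows-per-block pairing is an inference from the launch configuration, not stated in the source. The grid arithmetic in Python:

    def num_blocks(m: int) -> int:
        return (m + 3) // 4   # ceil(m / 4) with integer division

    assert num_blocks(1) == 1 and num_blocks(4) == 1 and num_blocks(5) == 2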
...
...
@@ -757,8 +757,8 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);

 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half *A, unsigned char *B, float *absmax, float *datatype, half *out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize);

 //template void gemm_host<float>(int m, int n, int k, float *A, float *B, float *out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half *A, half *B, half *out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rows, int cols);
...
...
csrc/ops.cuh
...
...
@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half *out, int lda, int ldb, int rows

 template <typename T> void gemm_host(int m, int n, int k, T *A, T *B, T *out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T *A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T *A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T *A, unsigned char *B, float *absmax, float *datatype, T *out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
...
...
csrc/pythonInterface.c
...
...
@@ -28,11 +28,11 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l

 void gemm_4bit_inference(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, float *datatype, half *out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
...
...
@@ -394,11 +394,11 @@ extern "C"

	CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
	CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

-	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, float *datatype, half *out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #endif

 void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
...
...
tests/test_functional.py
...
...
@@ -1816,7 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):

     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

-    F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)

     # warmup
     for i in range(iters):
...
...
@@ -1849,7 +1849,7 @@ def test_bench_matmul(batch, seq, model, hidden):

     t0 = time.time()
     for i in range(iters):
         #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
...
...
@@ -2351,76 +2351,14 @@ def test_normal_map_tree():

     print(pivots)

-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
-def test_cutlass3_gemm(dtype):
-    debug = True
-    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
-    #for dim in [4096, 5120, 6656, 8192]:
-    for dim in [4096]:
-    #for dim in [128+1]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
-        for i in range(100):
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            C1 = torch.matmul(A, B.t())
-            C2 = F.cutlass3_gemm(A, B.t())
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2)
-            mag = torch.abs(C1)+1e-8
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            errs.append(err)
-            relerrs.append(relerr)
-            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            #    print('')
-            #    print(i, err, relerr)
-            #    print(A.flatten()[-6:])
-            #    print(B.flatten()[-6:])
-            #    out = A.flatten()[-6:]*B.flatten()[-6:]
-            #    print(out)
-            #    print(out[:-1].sum())
-            #    print('='*80)
-            #    print(C1.flatten()[-6:])
-            #    print(C2.flatten()[-6:])
-            #    #assert False, 'ERROR'
-            c = int(C1.numel()*0.0014*(dim/256))+1
-            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
-        #print(c/math.sqrt(dim))
-        print('')
-        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        print(dim, (max_err.item(), max_relerr.item()))

 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemm_4bit(dtype):
+def test_gemv_4bit(dtype):
     print('')
     #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4*1024]:
+    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
         errs = []
         relerrs = []
         max_err = 0
...
...
@@ -2446,9 +2384,10 @@ def test_gemm_4bit(dtype):

             qB, state = F.quantize_nf4(B)
             F.dequantize_nf4(qB, state)

-            #C2 = bnb.matmul_4bit(A, qB.t(), state)
-            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-            C1 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            C3 = torch.matmul(A, B.t())
+            A.requires_grad = True
+            C1 = bnb.matmul_4bit(A, qB.t(), state)

             #print(state)
             #print(qB)
...
...
@@ -2457,8 +2396,7 @@ def test_gemm_4bit(dtype):

             #print(A)
             #print(B)
             #print('='*89)
-            #print(C1)
-            #print(C2)
+            #print(C3.flatten()[-20:])
             #print(C3)
             #print(C1.shape, C2.shape)
...
...
@@ -2485,10 +2423,16 @@ def test_gemm_4bit(dtype):

         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
         #print(dim, (max_err.item(), max_relerr.item()))
         print(C1.flatten()[-20:])
         print(C2.flatten()[-20:])
         print(sum(errs)/len(errs)/math.sqrt(dim), 0.00015)
         print(sum(relerrs)/len(relerrs)/math.sqrt(dim), 0.0015)
-        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        if dtype == torch.float16:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+        else:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
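The renamed test_gemv_4bit now sweeps dimensions from 64 to 4096 and applies dtype-specific error budgets, with fp16 held to a tighter bound than bf16 (bf16 keeps fewer mantissa bits). A hedged sketch of the error metric behind those asserts, mirroring the err/relerr bookkeeping visible in the removed test above:

    import math
    import torch

    def error_stats(C_ref: torch.Tensor, C_test: torch.Tensor, dim: int):
        err = torch.abs(C_ref - C_test)              # absolute error
        relerr = err / (torch.abs(C_ref) + 1e-8)     # relative error
        return err.mean().item() / math.sqrt(dim), relerr.mean().item() / math.sqrt(dim)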
...
...