OpenDAS / bitsandbytes — commit 5fab6734
Authored Jul 09, 2023 by Tim Dettmers
Parent: cef519c8

    Added fp32 compute type for gemv_4bit.

Showing 7 changed files with 56 additions and 105 deletions (+56 −105):
    bitsandbytes/functional.py    +15  −83
    csrc/kernels.cu               +22   −9
    csrc/kernels.cuh               +1   −1
    csrc/ops.cu                    +6   −4
    csrc/ops.cuh                   +1   −1
    csrc/pythonInterface.c         +8   −2
    tests/test_functional.py       +3   −5
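In short, the commit threads a 32-bit float compute type through the whole gemv_4bit stack: the Python dispatcher in bitsandbytes/functional.py gains a torch.float32 branch, csrc/pythonInterface.c exposes cgemm_4bit_inference_naive_fp32, and the CUDA kernel gets a BITS template parameter so it can load fp32 activations correctly. A minimal usage sketch of the new path follows; it assumes a CUDA build and the quantize_4bit/gemv_4bit calling convention exercised by test_gemv_4bit, so treat the exact argument names as an assumption rather than documented API. The per-file diffs follow.

    import torch
    import bitsandbytes.functional as F

    dim = 2048
    # fp32 activation vector (gemv_4bit expects a vector, e.g. shape [1, 1, dim])
    A = torch.randn(1, 1, dim, dtype=torch.float32, device="cuda")
    # fp32 weight matrix to be quantized to 4 bits
    B = torch.randn(4 * dim, dim, dtype=torch.float32, device="cuda")

    # quantize_4bit returns the packed 4-bit weights plus the quantization state
    qB, state = F.quantize_4bit(B, quant_type="nf4")

    # with this commit, fp32 inputs dispatch to cgemm_4bit_inference_naive_fp32
    out = F.gemv_4bit(A, qB.t(), state=state)
    print(out.shape, out.dtype)  # torch.Size([1, 1, 8192]) torch.float32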
bitsandbytes/functional.py

@@ -1464,6 +1464,9 @@ def gemv_4bit(
     if state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

+    if A.numel() != A.shape[-1]:
+        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
+
     Bshape = state[1]
     bout = Bshape[0]
     absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state

@@ -1474,90 +1477,17 @@ def gemv_4bit(
     if out is None:
         if len(A.shape) == 3:
-            out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
+            out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
         else:
-            out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
-    sA = A.shape
-    sB = B.shape
-    if transposed_A and len(sA) == 2:
-        sA = (sA[1], sA[0])
-    elif transposed_A and len(sA) == 3:
-        sA = (sA[0], sA[2], sA[0])
-    if transposed_B and len(sB) == 2:
-        sB = (sB[1], sB[0])
-    elif transposed_B and len(sB) == 3:
-        sB = (sB[0], sB[2], sB[0])
-    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
-    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
-    # (transpose of row major is column major)
-    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
-    # matrices in the input arguments for cuBLAS
-    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
-    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
-    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
-    if len(sB) == 2:
-        if B.stride()[0] == B.shape[1]:
-            transposed_B = False
-        elif B.stride()[1] == B.shape[0]:
-            transposed_B = True
-        if len(A.shape) == 2:
-            if A.stride()[0] == A.shape[1]:
-                transposed_A = False
-            elif A.stride()[1] == A.shape[0]:
-                transposed_A = True
-        else:
-            if A.stride()[1] == A.shape[2]:
-                transposed_A = False
-            elif A.stride()[2] == A.shape[1]:
-                transposed_A = True
-
-        if len(sA) == 2:
-            n = sA[0]
-            ldb = A.stride()[1 if transposed_A else 0]
-        elif len(sA) == 3 and len(sB) == 2:
-            n = sA[0] * sA[1]
-            ldb = sA[2]
-
-        m = sB[1]
-        k = sB[0]
-        lda = B.stride()[0]
-        ldc = sB[1]
-    elif len(sB) == 3:
-        # special case
-        assert len(sA) == 3
-        if not (sA[0] == sB[0] and sA[1] == sB[1]):
-            raise ValueError(f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}")
-
-        transposed_A = True
-        transposed_B = False
-
-        m = sB[2]
-        n = sA[2]
-        k = sB[0] * sB[1]
-
-        lda = n
-        ldb = sA[2]
-        ldc = m
-    # B^T @ A^T = C^T
-    # [km, nk -> mn]
-    #lda = ldb = ldc = 1
-    #lda = 1
-    if state is not None:
-        m = Bshape[0]
-        k = Bshape[1]
-        lda = Bshape[0]
-        ldc = Bshape[0]
-        ldb = (ldb+1)//2
-    #print(m, n, k, lda, ldb, ldc)
-    is_on_gpu([B, A, out])
+            out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
+
+    n = 1
+    m = Bshape[0]
+    k = Bshape[1]
+    lda = Bshape[0]
+    ldc = Bshape[0]
+    ldb = (A.shape[-1]+1)//2
+    is_on_gpu([B, A, out, absmax, state[-1]])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)

@@ -1570,6 +1500,8 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
             lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        elif A.dtype == torch.float32:
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
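The rewritten Python path above drops the generic cuBLAS-style shape handling and hard-codes the GEMV case: A must be a single vector, the output dimension comes from Bshape = state[1], and the leading dimension of the packed weights is halved because two 4-bit values share one byte. A small illustrative sketch of that bookkeeping (the helper below is not part of the library, just a restatement of the m/n/k/ld* assignments in the diff):

    def gemv_4bit_dims(A_shape, B_shape):
        # Mirrors the simplified setup in gemv_4bit after this commit (illustrative only).
        m, k = B_shape                     # Bshape = state[1]: (output features, input features)
        n = 1                              # A is treated as a single vector
        lda = ldc = m                      # leading dimensions of the weight and the output
        ldb = (A_shape[-1] + 1) // 2       # two 4-bit values per byte, rounded up
        return m, n, k, lda, ldb, ldc

    # e.g. A of shape [1, 1, 2048] against a quantized (8192, 2048) weight:
    print(gemv_4bit_dims((1, 1, 2048), (8192, 2048)))   # (8192, 1, 2048, 8192, 1024, 8192)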
csrc/kernels.cu

@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
 }

 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
   // per threadblock:

@@ -3528,7 +3528,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
-  //typedef cub::WarpReduce<T> WarpReduce;
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

@@ -3536,7 +3535,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   const int warp_lane = threadIdx.x % 32;
   const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
   const int num_values_8bit = num_values_4bit/2;
-  //T local_C = T(0.0f);
   float local_C = 0.0f;

   unsigned char local_B_4bit[num_values_8bit];

@@ -3585,10 +3583,24 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
       if(inner_idx + num_values_4bit)
       {
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+        if(BITS == 16)
+        {
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
+        }
+        else
+        {
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
+        }
       }
       else

@@ -3776,8 +3788,9 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

 template __global__ void kExtractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
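The new BITS template parameter controls how the kernel stages A into registers: every load is a 16-byte int4, which holds eight 16-bit values but only four 32-bit values, so covering the num_values_4bit = 32 activations per iteration takes 4 loads for fp16/bf16 and 8 loads for fp32, with correspondingly different int4 index strides. A small arithmetic sketch of that relationship (illustrative only, not code from the repository):

    NUM_VALUES_4BIT = 32          # activations of A consumed per inner iteration, as in the kernel
    INT4_BYTES = 16               # one CUDA int4 vector load

    for bits in (16, 32):                                 # half/bfloat16 vs float32
        elems_per_load = INT4_BYTES // (bits // 8)        # 8 halves or 4 floats per int4
        loads = NUM_VALUES_4BIT // elems_per_load         # 4 loads (BITS==16) or 8 loads (BITS==32)
        index_stride = NUM_VALUES_4BIT // loads           # 8 -> inner_idx/(num_values_4bit/4); 4 -> inner_idx/(num_values_4bit/8)
        print(f"BITS={bits}: {loads} int4 loads, int4 index = inner_idx/{index_stride}")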
csrc/kernels.cuh

@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
csrc/ops.cu

@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
   int num_blocks = (m+3)/4;
-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)

@@ -757,8 +757,10 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
csrc/ops.cuh

@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
csrc/pythonInterface.c

@@ -29,10 +29,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

 void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
+void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \

@@ -400,6 +403,9 @@ extern "C"
 	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
 	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

+	void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
 #endif

 	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
 	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
tests/test_functional.py

@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)

-batch_size = 5
+batch_size = 1
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))

@@ -1793,7 +1793,7 @@ values.append((batch_size, seqdim, 6656, 4*6656))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
-    iters = 80
+    iters = 1000
     formatB = F.get_special_format_str()
     A = torch.randn(batch, seq, model, device="cuda").half()

@@ -2361,9 +2361,7 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]: