OpenDAS / bitsandbytes · Commits · 5fab6734

Commit 5fab6734, authored Jul 09, 2023 by Tim Dettmers
Parent: cef519c8

    Added fp32 compute type for gemv_4bit.

Showing 7 changed files with 56 additions and 105 deletions.
bitsandbytes/functional.py   +15   -83
csrc/kernels.cu              +22    -9
csrc/kernels.cuh              +1    -1
csrc/ops.cu                   +6    -4
csrc/ops.cuh                  +1    -1
csrc/pythonInterface.c        +8    -2
tests/test_functional.py      +3    -5
bitsandbytes/functional.py

@@ -1464,6 +1464,9 @@ def gemv_4bit(
    if state is None:
        raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

    if A.numel() != A.shape[-1]:
        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

    Bshape = state[1]
    bout = Bshape[0]
    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state

@@ -1474,90 +1477,17 @@ def gemv_4bit(
    if out is None:
        if len(A.shape) == 3:
            out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
        else:
            out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[0])
    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
            out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True
        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]
        m = sB[1]
        k = sB[0]
        lda = B.stride()[0]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}")
        transposed_A = True
        transposed_B = False
        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]
        lda = n
        ldb = sA[2]
        ldc = m
    # B^T @ A^T = C^T
    # [km, nk -> mn]
    #lda = ldb = ldc = 1
    #lda = 1
    if state is not None:
        m = Bshape[0]
        k = Bshape[1]
        lda = Bshape[0]
        ldc = Bshape[0]
        ldb = (ldb+1)//2
    #print(m, n, k, lda, ldb, ldc)
    is_on_gpu([B, A, out])
    out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
    n = 1
    m = Bshape[0]
    k = Bshape[1]
    lda = Bshape[0]
    ldc = Bshape[0]
    ldb = (A.shape[-1]+1)//2
    is_on_gpu([B, A, out, absmax, state[-1]])
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

@@ -1570,6 +1500,8 @@ def gemv_4bit(
            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
        elif A.dtype == torch.bfloat16:
            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
        elif A.dtype == torch.float32:
            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
        else:
            raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
    else:
        ...
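For reference, a minimal sketch of how the new float32 path would be exercised from Python. The call pattern mirrors test_gemv_4bit in tests/test_functional.py; the tensor sizes, device, and quant_type='nf4' below are illustrative assumptions rather than part of this commit.

    import torch
    import bitsandbytes.functional as F

    # A must be a "vector" in the sense checked above: A.numel() == A.shape[-1].
    A = torch.randn(1, 1, 4096, dtype=torch.float32, device='cuda')
    B = torch.randn(4096, 4096, dtype=torch.float32, device='cuda')

    qB, state = F.quantize_4bit(B, quant_type='nf4')   # gemv_4bit requires the state from quantize_4bit()
    out = F.gemv_4bit(A, qB.t(), state=state)          # fp32 A now dispatches to cgemm_4bit_inference_naive_fp32
    print(out.shape, out.dtype)                        # expected: torch.Size([1, 1, 4096]), torch.float32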
csrc/kernels.cu

@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
}

#define num_values_4bit 32
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{
  // per threadblock:

@@ -3528,7 +3528,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
  // 4 warps -> 4 loads per iter
  // 1x128 * 128x4 -> 1x4 outputs
  //typedef cub::WarpReduce<T> WarpReduce;
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

@@ -3536,7 +3535,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
  const int warp_lane = threadIdx.x % 32;
  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
  const int num_values_8bit = num_values_4bit/2;
  //T local_C = T(0.0f);
  float local_C = 0.0f;

  unsigned char local_B_4bit[num_values_8bit];

@@ -3585,10 +3583,24 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
    if(inner_idx + num_values_4bit)
    {
      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];

      if(BITS == 16)
      {
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
      }
      else
      {
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
      }
    }
    else

@@ -3776,8 +3788,9 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

template __global__ void kExtractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
...
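The new BITS template parameter only changes how many 16-byte int4 loads are needed to stage num_values_4bit = 32 elements of A per inner iteration: 16-bit inputs (half/bf16) pack 8 elements per load, fp32 packs 4. A quick sanity check of those counts (plain Python, not part of the commit):

    num_values_4bit = 32        # elements of A staged per inner iteration (see the #define above)
    int4_bytes = 16             # one int4 vectorized load moves 16 bytes

    for name, elem_bytes in [("half/bf16 (BITS == 16)", 2), ("float (BITS == 32)", 4)]:
        elems_per_load = int4_bytes // elem_bytes      # 8 for 16-bit types, 4 for fp32
        loads = num_values_4bit // elems_per_load      # 4 loads vs. 8 loads, matching the two branches
        print(f"{name}: {elems_per_load} elements per int4 load -> {loads} loads")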
csrc/kernels.cuh

@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
...
csrc/ops.cu

@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{
  int num_blocks = (m+3)/4;

  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
}

template <typename T, int FUNC> void func(T *A, T *B, T value, long n)

@@ -757,8 +757,10 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
...
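The launch configuration itself is unchanged by this commit: 128 threads per block is 4 warps, and the kernel assigns one output row of B to each warp (row_B = (THREADS/32)*blockIdx.x + warp_idx in kernels.cu), so (m+3)/4 blocks cover all m rows. A small sketch of that arithmetic:

    def launch_config(m, threads=128, warp_size=32):
        warps_per_block = threads // warp_size                       # 4 warps per block
        num_blocks = (m + warps_per_block - 1) // warps_per_block    # same as (m+3)/4 in ops.cu
        return num_blocks, threads

    print(launch_config(4096))   # (1024, 128)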
csrc/ops.cuh

@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
...
csrc/pythonInterface.c

@@ -29,10 +29,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \

@@ -400,6 +403,9 @@ extern "C"
  void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

  void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

#endif

  void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
  void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
...
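A quick way to confirm the new entry point is exported after rebuilding is to load the shared library with ctypes and look the symbols up. The library file name below is only an example; it depends on the CUDA version the library is built against.

    import ctypes as ct

    lib = ct.cdll.LoadLibrary('./libbitsandbytes_cuda118.so')   # example name, adjust to your build
    for sym in ('cgemm_4bit_inference_naive_fp16',
                'cgemm_4bit_inference_naive_bf16',
                'cgemm_4bit_inference_naive_fp32'):
        print(sym, hasattr(lib, sym))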
tests/test_functional.py

@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    print("partial matmul", time.time() - t0)

batch_size = 5
batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))

@@ -1793,7 +1793,7 @@ values.append((batch_size, seqdim, 6656, 4*6656))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 80
    iters = 1000
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()

@@ -2361,9 +2361,7 @@ def test_normal_map_tree():
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant):
    print('')
    for dim in [128, 256, 512, 1024, 2048, 4096]:
...
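With the fp32 id added to the dtype parametrization, the new compute type can be exercised in isolation, for example as sketched below (the -k expression simply matches the parametrize ids above):

    # Run only the fp32 parametrizations of test_gemv_4bit:
    #   pytest tests/test_functional.py -k "test_gemv_4bit and fp32"
    import pytest
    pytest.main(["tests/test_functional.py", "-k", "test_gemv_4bit and fp32"])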