Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
21723f79
"src/git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "428c58134064eaa3175d552fcc76caee5b4bd950"
Commit
21723f79
authored
Apr 29, 2023
by
Tim Dettmers
Browse files
4-bit draft.
parent
cad83994
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
273 additions
and
27 deletions
+273
-27
bitsandbytes/functional.py
bitsandbytes/functional.py
+18
-4
csrc/kernels.cu
csrc/kernels.cu
+201
-21
csrc/kernels.cuh
csrc/kernels.cuh
+1
-0
csrc/ops.cu
csrc/ops.cu
+18
-0
csrc/ops.cuh
csrc/ops.cuh
+1
-0
csrc/pythonInterface.c
csrc/pythonInterface.c
+6
-0
tests/test_functional.py
tests/test_functional.py
+28
-2
No files found.
bitsandbytes/functional.py
View file @
21723f79
...
...
@@ -1380,10 +1380,15 @@ def cutlass3_gemm(
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
,
state
=
None
):
sout
=
check_matmul
(
A
,
B
,
out
,
transposed_A
,
transposed_B
,
expected_type
=
A
.
dtype
)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if
state
is
None
:
Bshape
=
B
.
shape
else
:
Bshape
=
state
[
1
]
if
out
is
None
:
out
=
torch
.
zeros
(
size
=
sout
,
dtype
=
A
.
dtype
,
device
=
A
.
device
)
out
=
torch
.
zeros
(
size
=
(
A
.
shape
[
0
],
Bshape
[
1
])
,
dtype
=
A
.
dtype
,
device
=
A
.
device
)
sA
=
A
.
shape
sB
=
B
.
shape
...
...
@@ -1456,7 +1461,13 @@ def cutlass3_gemm(
# [km, nk -> mn]
#lda = ldb = ldc = 1
#lda = 1
#print(m, n, k, lda, ldb, ldc)
if
state
is
not
None
:
m
=
Bshape
[
0
]
k
=
Bshape
[
1
]
lda
=
Bshape
[
1
]
ldc
=
Bshape
[
0
]
ldb
=
(
ldb
+
1
)
//
2
print
(
m
,
n
,
k
,
lda
,
ldb
,
ldc
)
is_on_gpu
([
B
,
A
,
out
])
m
=
ct
.
c_int32
(
m
)
n
=
ct
.
c_int32
(
n
)
...
...
@@ -1464,7 +1475,10 @@ def cutlass3_gemm(
lda
=
ct
.
c_int32
(
lda
)
ldb
=
ct
.
c_int32
(
ldb
)
ldc
=
ct
.
c_int32
(
ldc
)
if
A
.
dtype
==
torch
.
float32
:
if
B
.
dtype
==
torch
.
uint8
:
lib
.
cgemm_4bit_inference
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
state
[
0
]),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
[
3
]))
elif
A
.
dtype
==
torch
.
float32
:
lib
.
cgemm_host_fp32
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
)
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cgemm_host_fp16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
)
...
...
csrc/kernels.cu
View file @
21723f79
...
...
@@ -69,6 +69,27 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
}
}
__device__
float
d2DequantizeFP4
(
unsigned
char
val
)
{
float
sign
=
(
val
&
0b1000
)
==
8
?
-
1.0
f
:
1.0
f
;
if
((
val
&
0b0110
)
==
0
)
{
// subnormal
if
((
val
&
0b0001
)
==
0
)
return
0.0
f
;
else
return
sign
*
0.0625
f
;
}
else
{
// normal
float
exponent
=
((
val
&
0b0100
)
==
4
?
2.0
f
:
8.0
f
)
+
((
val
&
0b0010
)
==
2
?
0.0
f
:
2.0
f
);
float
fraction
=
(
val
&
0b0001
)
==
1
?
1.5
f
:
1.0
f
;
return
sign
*
exponent
*
fraction
;
}
}
__device__
float
dDequantizeFP4Tree
(
unsigned
char
val
,
float
absmax
)
{
float
sign
=
(
val
&
0b1000
)
==
8
?
-
1.0
f
:
1.0
f
;
...
...
@@ -145,7 +166,61 @@ __device__ unsigned char dQuantizeFP4(float x)
return
0b0000
+
sign
;
}
__device__
float
dDequantizeNF4
(
unsigned
char
val
,
float
absmax
)
__device__
half
dhDequantizeNF4
(
unsigned
char
val
)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if
((
val
&
0b1000
)
==
8
)
if
((
val
&
0b0100
)
==
4
)
// 1
if
((
val
&
0b0010
)
==
2
)
// 11
if
((
val
&
0b0001
)
==
1
)
// 111
return
1.0
f
;
else
return
0.7229568362236023
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 110
return
0.5626170039176941
f
;
else
return
0.44070982933044434
f
;
else
if
((
val
&
0b0010
)
==
2
)
//10
if
((
val
&
0b0001
)
==
1
)
// 101
return
0.33791524171829224
f
;
else
return
0.24611230194568634
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 100
return
0.16093020141124725
f
;
else
return
0.07958029955625534
f
;
else
if
((
val
&
0b0100
)
==
4
)
// 0
if
((
val
&
0b0010
)
==
2
)
//01
if
((
val
&
0b0001
)
==
1
)
// 011
return
0.0
f
;
else
return
-
0.09105003625154495
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 010
return
-
0.18477343022823334
f
;
else
return
-
0.28444138169288635
f
;
else
if
((
val
&
0b0010
)
==
2
)
//00
if
((
val
&
0b0001
)
==
1
)
// 001
return
-
0.39491748809814453
f
;
else
return
-
0.5250730514526367
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 000
return
-
0.6961928009986877
f
;
else
return
-
1.0
f
;
}
__device__
float
dDequantizeNF4
(
unsigned
char
val
)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
...
...
@@ -153,49 +228,49 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
if
((
val
&
0b0100
)
==
4
)
// 1
if
((
val
&
0b0010
)
==
2
)
// 11
if
((
val
&
0b0001
)
==
1
)
// 111
return
1.0
f
*
absmax
;
return
1.0
f
;
else
return
0.7229568362236023
f
*
absmax
;
return
0.7229568362236023
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 110
return
0.5626170039176941
f
*
absmax
;
return
0.5626170039176941
f
;
else
return
0.44070982933044434
f
*
absmax
;
return
0.44070982933044434
f
;
else
if
((
val
&
0b0010
)
==
2
)
//10
if
((
val
&
0b0001
)
==
1
)
// 101
return
0.33791524171829224
f
*
absmax
;
return
0.33791524171829224
f
;
else
return
0.24611230194568634
f
*
absmax
;
return
0.24611230194568634
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 100
return
0.16093020141124725
f
*
absmax
;
return
0.16093020141124725
f
;
else
return
0.07958029955625534
f
*
absmax
;
return
0.07958029955625534
f
;
else
if
((
val
&
0b0100
)
==
4
)
// 0
if
((
val
&
0b0010
)
==
2
)
//01
if
((
val
&
0b0001
)
==
1
)
// 011
return
0.0
f
*
absmax
;
return
0.0
f
;
else
return
-
0.09105003625154495
f
*
absmax
;
return
-
0.09105003625154495
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 010
return
-
0.18477343022823334
f
*
absmax
;
return
-
0.18477343022823334
f
;
else
return
-
0.28444138169288635
f
*
absmax
;
return
-
0.28444138169288635
f
;
else
if
((
val
&
0b0010
)
==
2
)
//00
if
((
val
&
0b0001
)
==
1
)
// 001
return
-
0.39491748809814453
f
*
absmax
;
return
-
0.39491748809814453
f
;
else
return
-
0.5250730514526367
f
*
absmax
;
return
-
0.5250730514526367
f
;
else
if
((
val
&
0b0001
)
==
1
)
// 000
return
-
0.6961928009986877
f
*
absmax
;
return
-
0.6961928009986877
f
;
else
return
-
1.0
f
*
absmax
;
return
-
1.0
f
;
}
...
...
@@ -800,8 +875,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
{
vals
[
j
*
2
]
=
dDequantizeNF4
(
qvals
[
j
]
>>
4
,
local_abs_max
)
;
vals
[
j
*
2
+
1
]
=
dDequantizeNF4
(
qvals
[
j
]
&
0x0F
,
local_abs_max
)
;
vals
[
j
*
2
]
=
dDequantizeNF4
(
qvals
[
j
]
>>
4
)
*
local_abs_max
;
vals
[
j
*
2
+
1
]
=
dDequantizeNF4
(
qvals
[
j
]
&
0x0F
)
*
local_abs_max
;
}
break
;
}
...
...
@@ -2947,7 +3022,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//}
template
<
typename
T
,
typename
TCAST
,
int
ITEMS
>
__device__
inline
void
vector_load
(
T
*
local
,
T
*
__restrict__
const
buffer
,
int
idx
,
int
limit_base
,
int
limit
)
template
<
typename
T
,
typename
TCAST
,
int
ITEMS
>
__device__
inline
void
vector_load
(
T
*
local
,
T
*
__restrict__
const
buffer
,
int
idx
,
int
limit_base
,
int
limit
,
float
zero_value
=
0.0
f
)
{
if
(
limit_base
+
ITEMS
<=
limit
)
reinterpret_cast
<
TCAST
*>
(
local
)[
0
]
=
reinterpret_cast
<
TCAST
*>
(
buffer
)[
idx
/
ITEMS
];
...
...
@@ -2958,7 +3033,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
if
(
limit_base
+
k
<
limit
)
local
[
k
]
=
buffer
[
idx
+
k
];
else
local
[
k
]
=
0.0
f
;
local
[
k
]
=
(
T
)
zero_value
;
}
}
}
...
...
@@ -3024,6 +3099,109 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
out
[
col_offset
+
threadIdx
.
x
]
=
smem_C
[
threadIdx
.
x
];
}
template
<
typename
T
,
int
THREADS
>
__global__
void
kgemm_4bit_inference
(
int
M
,
int
N
,
int
K
,
T
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
typedef
cub
::
BlockReduce
<
T
,
THREADS
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
reduce
;
int
col_offset
=
blockIdx
.
x
*
8
;
T
local_A
[
32
];
unsigned
char
local_B_4bit
[
16
];
T
local_B
[
32
];
T
local_C
[
8
];
__shared__
T
smem_C
[
8
];
if
(
threadIdx
.
x
<
8
)
smem_C
[
threadIdx
.
x
]
=
T
(
0
);
__syncthreads
();
#pragma unroll 8
for
(
int
k
=
0
;
k
<
8
;
k
++
)
local_C
[
k
]
=
T
(
0
);
for
(
int
idx
=
threadIdx
.
x
*
32
;
idx
<
K
;
idx
+=
blockDim
.
x
*
32
)
{
// we load only 8 values per iteration from A, so we
// need to do 4 loads for every single load from B
// for B, we have packed values, so the 16 8-bit values
// turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads
vector_load
<
T
,
int4
,
8
>
(
local_A
,
A
,
idx
,
idx
,
K
);
vector_load
<
T
,
int4
,
8
>
(
&
(
local_A
[
8
]),
A
,
idx
+
8
,
idx
+
8
,
K
);
vector_load
<
T
,
int4
,
8
>
(
&
(
local_A
[
16
]),
A
,
idx
+
16
,
idx
+
16
,
K
);
vector_load
<
T
,
int4
,
8
>
(
&
(
local_A
[
24
]),
A
,
idx
+
24
,
idx
+
24
,
K
);
for
(
int
col
=
0
;
col
<
8
;
col
++
)
{
if
((
col
+
col_offset
)
>=
M
){
break
;
}
int
offset_B
=
(
col_offset
+
col
)
*
ldb
;
// 0111 -> 0.0f in NF4
// since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111
vector_load
<
unsigned
char
,
int4
,
16
>
(
local_B_4bit
,
B
,
(
offset_B
+
idx
+
1
)
/
2
,
(
idx
+
1
)
/
2
,
(
K
+
1
)
/
2
,
0b01110111
);
int
absidx
=
(
idx
+
offset_B
)
/
blocksize
;
half
local_absmax
=
__ldg
(
&
(
absmax
[
absidx
]));
//for(int k = 0; k < 16; k++)
//printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F);
//printf("\n");
//vector_load<T, int4, 8>(local_A, A, idx, idx, K);
#pragma unroll 16
for
(
int
k
=
0
;
k
<
16
;
k
++
)
{
//if(local_B_4bit[k ] != 0b01110111)
//printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax),
//local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax));
//local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax;
//local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax;
local_B
[
k
*
2
]
=
(
half
)(
local_B_4bit
[
k
]
>>
4
)
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
(
half
)(
local_B_4bit
[
k
]
&
0x0F
)
*
local_absmax
;
//local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax;
//local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax;
}
#pragma unroll 32
//for(int k = 0; k < 8; k++)
for
(
int
k
=
0
;
k
<
32
;
k
++
)
{
local_C
[
col
]
+=
local_A
[
k
]
*
local_B
[
k
];
//if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
//if((float)local_B[k] != 0.0)
//printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]);
}
}
}
#pragma unroll 8
for
(
int
k
=
0
;
k
<
8
;
k
++
)
{
local_C
[
k
]
=
BlockReduce
(
reduce
).
Reduce
(
local_C
[
k
],
cub
::
Sum
());
__syncthreads
();
}
if
(
threadIdx
.
x
==
0
)
{
#pragma unroll 8
for
(
int
k
=
0
;
k
<
8
;
k
++
)
smem_C
[
k
]
=
local_C
[
k
];
}
else
if
(
threadIdx
.
x
>=
32
)
// early return for unused warps
return
;
__syncwarp
();
if
(
threadIdx
.
x
<
8
&&
col_offset
+
threadIdx
.
x
<
M
)
out
[
col_offset
+
threadIdx
.
x
]
=
smem_C
[
threadIdx
.
x
];
}
//#define ROWS 2
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
//{
...
...
@@ -3207,6 +3385,8 @@ template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half *
template
__global__
void
gemm_device
<
float
,
32
,
128
>(
int
M
,
int
N
,
int
K
,
float
*
__restrict__
const
A
,
float
*
B
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
);
template
__global__
void
gemm_device
<
half
,
16
,
128
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
half
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
);
template
__global__
void
kgemm_4bit_inference
<
half
,
128
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template
__global__
void
with_staging_unified
<
2
>(
float
const
*
global_in
,
float
*
global_out
,
size_t
size
,
size_t
batch_sz
);
...
...
csrc/kernels.cuh
View file @
21723f79
...
...
@@ -139,5 +139,6 @@ template <size_t stages_count /* Pipeline with stages_count stages */>
__global__
void
with_staging_unified
(
float
const
*
global_in
,
float
*
global_out
,
size_t
size
,
size_t
batch_sz
);
template
<
typename
T
,
int
BITS
,
int
THREADS
>
__global__
void
gemm_device
(
int
M
,
int
N
,
int
K
,
T
*
__restrict__
const
A
,
T
*
B
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
);
template
<
typename
T
,
int
THREADS
>
__global__
void
kgemm_4bit_inference
(
int
M
,
int
N
,
int
K
,
T
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
#endif
csrc/ops.cu
View file @
21723f79
...
...
@@ -695,10 +695,28 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
gemm_device
<
T
,
16
,
128
><<<
num_blocks
,
dimBlock
,
0
,
0
>>>
(
m
,
n
,
k
,
A
,
B
,
out
,
lda
,
ldb
,
ldc
);
}
template
<
typename
T
>
void
gemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
T
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
dim3
dimBlock
(
128
);
int
num_blocks
=
(
m
+
7
)
/
8
;
cout
<<
num_blocks
<<
endl
;
cout
<<
lda
<<
endl
;
cout
<<
ldb
<<
endl
;
cout
<<
ldc
<<
endl
;
cout
<<
m
<<
endl
;
cout
<<
n
<<
endl
;
cout
<<
k
<<
endl
;
kgemm_4bit_inference
<
T
,
128
><<<
num_blocks
,
dimBlock
,
0
,
0
>>>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template
void
gemm_4bit_inference
<
half
>(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
void
gemm_host
<
float
>(
int
m
,
int
n
,
int
k
,
float
*
A
,
float
*
B
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
bits
);
template
void
gemm_host
<
half
>(
int
m
,
int
n
,
int
k
,
half
*
A
,
half
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
bits
);
template
void
extractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
);
...
...
csrc/ops.cuh
View file @
21723f79
...
...
@@ -191,6 +191,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
void
matmul4bite
(
half
*
A
,
unsigned
char
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
rowsA
,
int
colsA
,
int
colsB
);
template
<
typename
T
>
void
gemm_host
(
int
m
,
int
n
,
int
k
,
T
*
A
,
T
*
B
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
bits
);
template
<
typename
T
>
void
gemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
T
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
void
pipeline_test
(
float
*
A
,
float
*
B
,
size_t
n
,
size_t
batch_size
);
...
...
csrc/pythonInterface.c
View file @
21723f79
...
...
@@ -25,6 +25,9 @@ void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, in
void
gemm_host_fp16
(
int
M
,
int
N
,
int
K
,
half
*
A
,
half
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
)
{
gemm_host
<
half
>
(
M
,
N
,
K
,
A
,
B
,
out
,
lda
,
ldb
,
ldc
,
16
);
}
void
gemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference
<
half
>
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \
...
...
@@ -319,6 +322,9 @@ extern "C"
void
cgemm_host_fp16
(
int
M
,
int
N
,
int
K
,
half
*
A
,
half
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
)
{
gemm_host_fp16
(
M
,
N
,
K
,
A
,
B
,
out
,
lda
,
ldb
,
ldc
);
}
void
cgemm_4bit_inference
(
int
m
,
int
n
,
int
k
,
half
*
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
gemm_4bit_inference
(
m
,
n
,
k
,
A
,
B
,
absmax
,
out
,
lda
,
ldb
,
ldc
,
blocksize
);
}
#endif
void
cquantize_blockwise_cpu_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
long
long
blocksize
,
long
long
n
){
quantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_cpu_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
long
long
blocksize
,
long
long
n
){
dequantize_cpu
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
...
...
tests/test_functional.py
View file @
21723f79
...
...
@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
print
(
pivots
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
'fp32'
,
'fp16'
])
#
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
#
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
def
test_cutlass3_gemm
(
dtype
):
for
i
in
range
(
1
):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
...
...
@@ -2373,6 +2373,32 @@ def test_cutlass3_gemm(dtype):
torch
.
testing
.
assert_close
(
C1
,
C2
,
atol
=
1e-05
,
rtol
=
0.005
)
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
def
test_gemm_4bit
(
dtype
):
for
i
in
range
(
1
):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#torch.random.manual_seed(17)
A
=
torch
.
rand
(
1
,
4096
,
dtype
=
dtype
,
device
=
'cuda'
)
B
=
torch
.
rand
(
4
*
4096
,
4096
,
dtype
=
dtype
,
device
=
'cuda'
)
#print('')
#print(A)
#print(B)
qB
,
state
=
F
.
quantize_nf4
(
B
)
F
.
dequantize_nf4
(
qB
,
state
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
#C1 = bnb.matmul_4bit(A, qB.t(), state)
C2
=
F
.
cutlass3_gemm
(
A
,
qB
.
t
(),
state
=
state
)
#print(C1)
#print(C2)
#torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)
def
test_pipeline_func
():
a
=
torch
.
rand
(
2
,
4
).
cuda
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment