OpenDAS / bitsandbytes

Commit 264a9485, authored May 02, 2023 by Tim Dettmers
Parent: 869b7e83

    4-bit draft; 128 vector load 240.
Showing 4 changed files with 278 additions and 136 deletions:

    bitsandbytes/functional.py    +4    -2
    csrc/kernels.cu               +196  -99
    csrc/ops.cu                   +9    -9
    tests/test_functional.py      +69   -26
bitsandbytes/functional.py
@@ -1385,10 +1385,12 @@ def cutlass3_gemm(
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
         Bshape = B.shape
+        bout = Bshape[1]
     else:
         Bshape = state[1]
+        bout = Bshape[0]
     if out is None:
-        out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
+        out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

     sA = A.shape
     sB = B.shape

@@ -1464,7 +1466,7 @@ def cutlass3_gemm(
     if state is not None:
         m = Bshape[0]
         k = Bshape[1]
-        lda = Bshape[1]
+        lda = Bshape[0]
         ldc = Bshape[0]
         ldb = (ldb+1)//2
     #print(m, n, k, lda, ldb, ldc)
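The `bout` change above selects the output width from the quantization state rather than from the packed tensor: with a dense B the width is `B.shape[1]`, while with a 4-bit B the original shape travels in `state[1]` and the width is its leading dimension. A minimal sketch of that selection in plain Python; the tuple used for `state` here is a hypothetical stand-in for the real bitsandbytes quantization state, kept only to the `state[1]`-holds-the-shape convention visible in the diff:

    import torch

    def select_out_shape(A, B, state=None):
        # Mirrors the cutlass3_gemm change above: with no quantization state the
        # output width comes from B.shape[1]; with a packed 4-bit B the original
        # (unpacked) shape travels in state[1] and the width is its leading dim.
        if state is None:
            Bshape = B.shape
            bout = Bshape[1]
        else:
            Bshape = state[1]   # assumption: state[1] holds the original shape of B, as in the diff
            bout = Bshape[0]
        return (A.shape[0], bout)

    A = torch.empty(1, 4096)
    B = torch.empty(4096, 4 * 4096)                                    # dense case, already transposed
    print(select_out_shape(A, B))                                      # (1, 16384)

    qB_packed = torch.empty(4 * 4096, 4096 // 2, dtype=torch.uint8)    # packed 4-bit, two values per byte
    state = (None, torch.Size([4 * 4096, 4096]))                       # hypothetical state tuple
    print(select_out_shape(A, qB_packed, state))                       # (1, 16384)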
csrc/kernels.cu
@@ -3044,22 +3044,15 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
 #define WARPS 5
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
-  typedef cub::WarpReduce<half> WarpReduce;
-  // Allocate WarpReduce shared memory for one warp
-  //__shared__ typename WarpReduce::TempStorage temp_storage;
-  //typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  //// Allocate shared memory for BlockReduce
-  //__shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x *32;
   const int warp_id = threadIdx.x / 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
+  const int val_per_iter = blockDim.x-32;

-  T local_A[2];
-  T local_B[64];
+  T local_A[4];
+  T local_B[128];

   const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);
@@ -3082,24 +3075,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     if(loaded_values == 0)
     {
       local_A[0] = A[idx];
-      local_A[1] = A[idx+blockDim.x-32];
+      local_A[1] = A[idx+(1*val_per_iter)];
+      local_A[2] = A[idx+(2*val_per_iter)];
+      local_A[3] = A[idx+(3*val_per_iter)];

       #pragma unroll 32
       for(int col = 0; col < 32; col++)
       {
         local_B[col] = B[(col_offset+col)*ldb+idx];
-        local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
+        local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
+        local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
+        local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
       }

-      loaded_values = 1;
+      loaded_values = 3;
     }
     else
     {
-      local_A[0] = local_A[1];
-      loaded_values--;
-
-      #pragma unroll 32
-      for(int col = 0; col < 32; col++)
-        local_B[col] = local_B[col+32];
+      if(loaded_values == 3)
+      {
+        local_A[0] = local_A[1];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col + (32)];
+      }
+      else if(loaded_values == 2)
+      {
+        local_A[0] = local_A[2];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col + (64)];
+      }
+      else
+      {
+        local_A[0] = local_A[3];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col + (96)];
+      }
+      loaded_values--;
     }

     smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
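The hunk above (and its twin inside the main loop, next) widens the per-thread staging buffers from 2 to 4 values of A and from 64 to 128 values of B, and turns `loaded_values` into a countdown from 3, so one round of global loads now feeds four pipeline steps spaced `val_per_iter = blockDim.x - 32` elements apart; with WARPS = 5 that stride is 128, which appears to be the "128 vector load" of the commit message. A small host-side sketch of the countdown, in plain Python rather than CUDA and assuming a 160-thread block:

    # Simulates the loaded_values countdown for one thread; hypothetical helper,
    # not part of the repository.
    WARPS = 5
    block_dim = WARPS * 32             # 160 threads per block
    val_per_iter = block_dim - 32      # 128: stride between staged loads

    def simulate(tid, steps=8):
        loaded_values = 0
        local_A = [None] * 4
        base_idx = 0
        for step in range(steps):
            idx = base_idx + tid
            if loaded_values == 0:
                # one wide load stages four values, val_per_iter apart
                local_A = [idx + i * val_per_iter for i in range(4)]
                loaded_values = 3
                print(f"step {step}: global load of A{local_A}")
            else:
                # reuse a previously staged value, oldest unused first, as in the kernel
                local_A[0] = local_A[4 - loaded_values]
                loaded_values -= 1
                print(f"step {step}: reuse staged value A[{local_A[0]}]")
            base_idx += val_per_iter

    simulate(tid=0)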
@@ -3139,26 +3153,46 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       if(loaded_values == 0)
       {
         local_A[0] = A[idx];
-        local_A[1] = A[idx+blockDim.x-32];
+        local_A[1] = A[idx+(1*val_per_iter)];
+        local_A[2] = A[idx+(2*val_per_iter)];
+        local_A[3] = A[idx+(3*val_per_iter)];

         #pragma unroll 32
         for(int col = 0; col < 32; col++)
         {
           local_B[col] = B[(col_offset+col)*ldb+idx];
-          local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
+          local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
+          local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
+          local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
         }

-        loaded_values = 1;
+        loaded_values = 3;
       }
       else
       {
-        local_A[0] = local_A[1];
-        loaded_values--;
-
-        #pragma unroll 32
-        for(int col = 0; col < 32; col++)
-          local_B[col] = local_B[col+32];
+        if(loaded_values == 3)
+        {
+          local_A[0] = local_A[1];
+          #pragma unroll 32
+          for(int col = 0; col < 32; col++)
+            local_B[col] = local_B[col + (32)];
+        }
+        else if(loaded_values == 2)
+        {
+          local_A[0] = local_A[2];
+          #pragma unroll 32
+          for(int col = 0; col < 32; col++)
+            local_B[col] = local_B[col + (64)];
+        }
+        else
+        {
+          local_A[0] = local_A[3];
+          #pragma unroll 32
+          for(int col = 0; col < 32; col++)
+            local_B[col] = local_B[col + (96)];
+        }
+        loaded_values--;
       }

       smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@@ -3215,104 +3249,166 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
-  typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage reduce;
-  int col_offset = blockIdx.x * 8;
-  T local_A[32];
-  unsigned char local_B_4bit[16];
-  T local_B[32];
-  T local_C[8];
+  int col_offset = blockIdx.x *32;
+  const int warp_id = threadIdx.x / 32;
+  const int half_warp_id = threadIdx.x / 16;
+  const int half_warp_lane = threadIdx.x % 16;
+  const int batch_size_warps = (WARPS-1)*2;

-  __shared__ T smem_C[8];
+  T local_A[2];
+  T local_B[64];
+  unsigned char local_B_4bit[32];

-  if(threadIdx.x < 8)
-    smem_C[threadIdx.x] = T(0);
-  __syncthreads();
+  const int a_tile_offset = 16;
+  const int b_tile_offset = (16*32 + 16);

-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
-    local_C[k] = T(0);
+  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
+  //__shared__ T smem_C[8*32];

+  wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
+  wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
+  wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
+  wmma::fill_fragment(c_frag, 0.0f);

-  for(int idx = threadIdx.x*32; idx < K; idx += blockDim.x*32)
-  {
+  int ticktock = 0;
+  int idx = 0 + threadIdx.x;
+  int loaded_values = 0;
+  // prefetch
+  if(idx < K && warp_id < (WARPS-1))
+  {
+    if(loaded_values == 0)
+    {
+      local_A[0] = A[idx];
+      local_A[1] = A[idx+blockDim.x-32];
-    // we load only 8 values per iteration from A, so we
-    // need to do 4 loads for every single load from B
-    // for B, we have packed values, so the 16 8-bit values
-    // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads
-    vector_load<T, int4, 8>(local_A, A, idx, idx, K);
-    vector_load<T, int4, 8>(&(local_A[8]), A, idx+8, idx+8, K);
-    vector_load<T, int4, 8>(&(local_A[16]), A, idx+16, idx+16, K);
-    vector_load<T, int4, 8>(&(local_A[24]), A, idx+24, idx+24, K);

+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B_4bit[col] = B[(col_offset+col)*ldb+idx];

-    for(int col = 0; col < 8; col++)
-    {
+      loaded_values = 1;
+    }
+    else
+    {
-      if((col + col_offset) >= M){ break; }
+      local_A[0] = local_A[1];
+      loaded_values--;

-      int offset_B = (col_offset+col)*ldb;
-      // 0111 -> 0.0f in NF4
-      // since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111
-      vector_load<unsigned char, int4, 16>(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111);

+      #pragma unroll 64
+      for(int col = 0; col < 64; col+=2)
+      {
+        local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+      }
+    }

-      int absidx = (idx + offset_B)/blocksize;
-      half local_absmax = __ldg(&(absmax[absidx]));

-      //for(int k = 0; k < 16; k++)
-        //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F);
-      //printf("\n");

+    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
+    //vector_load<T, int4, 8>(local_A, A, idx, idx, K);

+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
+  }
+  else if(warp_id < (WARPS-1))
+  {
+    local_A[0] = T(0.0);
+    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      local_B[col] = 0.0f;

+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+  }
+  ticktock = ticktock == 0 ? 1 : 0;

-      #pragma unroll 16
-      for(int k = 0; k < 16; k++)
-      {
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
+  for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
+  {
+    idx = base_idx + threadIdx.x;

+    __syncthreads();
+    if(idx < K && warp_id < (WARPS-1))
+    {
+      if(loaded_values == 0)
+      {
+        local_A[0] = A[idx];
+        local_A[1] = A[idx+blockDim.x-32];

-        //if(local_B_4bit[k ] != 0b01110111)
-          //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax),
-          //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax));
-        //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax;
-        //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax;
-        local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax;
-        local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax;
-        //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax;
-        //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax;
-      }

+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+        {
+          local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
+          local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
+        }

+        loaded_values = 1;
+      }
+      else
+      {
+        local_A[0] = local_A[1];
+        loaded_values--;

+        int absidx = (idx + col_offset)/blocksize;
+        half local_absmax = __ldg(&(absmax[absidx]));

+        #pragma unroll 64
+        for(int col = 0; col < 64; col+=2)
+        {
+          local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
+          local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+        }
+      }

+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

-      #pragma unroll 32
-      //for(int k = 0; k < 8; k++)
-      for(int k = 0; k < 32; k++)
-        local_C[col] += local_A[k]*local_B[k];
-        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
-        //if((float)local_B[k] != 0.0)
-        //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]);
-    }
-  }
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
+    }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = 0.0f;

+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+    }
+    ticktock = ticktock == 0 ? 1 : 0;

+    if(warp_id == (WARPS-1))
+      for(int k = 0; k < batch_size_warps; k++)
+      {
+        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+      }
+  }

-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
-  {
-    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
-    __syncthreads();
-  }
+  __syncthreads();
+  if(warp_id != (WARPS-1)){ return; }
+  // only warp_id == (WARPS-1) from here
+  int warp_lane = threadIdx.x % 32;

-  if(threadIdx.x == 0)
+  ticktock = ticktock == 0 ? 1 : 0;
+  for(int k = 0; k < batch_size_warps; k++)
+  {
-    #pragma unroll 8
-    for(int k = 0; k < 8; k++)
-      smem_C[k] = local_C[k];
+    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  }
-  else if(threadIdx.x >= 32)
-    // early return for unused warps
-    return;

-  __syncwarp();
+  __syncwarp(); // 129 mu

+  if(warp_id == (WARPS-1))
+    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);

-  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
-    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
+  if(col_offset + warp_lane < M)
+    out[col_offset + warp_lane] = smem_A[warp_lane];
 }

 //#define ROWS 2
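In the rewritten `kgemm_4bit_inference`, B arrives as packed bytes: each byte carries two 4-bit codes, high nibble first, and `dhDequantizeNF4` maps a code to its normalized NF4 value, which is then scaled by the per-block `absmax` (the draft still multiplies by `T(absidx)` and `T(1.0f)` in places). A sketch of that packing convention in plain Python; the codebook constants are reproduced from the commonly cited NF4 table and should be treated as an assumption of this sketch, though code 0b0111 mapping to 0.0 matches the comment in the removed code:

    # Hypothetical helper illustrating the packed-nibble convention; not part of the repository.
    NF4 = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
           -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
           0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
           0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
           0.7229568362236023, 1.0]  # assumed codebook values

    def unpack_nf4_byte(byte, absmax):
        """Unpack one packed byte into two dequantized values, high nibble first,
        mirroring local_B[col] / local_B[col+1] in the kernel."""
        hi = (byte >> 4) & 0x0F
        lo = byte & 0x0F
        return NF4[hi] * absmax, NF4[lo] * absmax

    # The byte 0b01110111 packs two copies of code 7, i.e. two zeros, which is
    # why the old vector_load padded out-of-range elements with 0b01110111.
    print(unpack_nf4_byte(0b01110111, absmax=1.0))   # (0.0, 0.0)
    print(unpack_nf4_byte(0xF0, absmax=0.5))         # (0.5, -0.5): code 15 -> 1.0, code 0 -> -1.0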
@@ -3513,6 +3609,7 @@ template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * _
 template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half *B, half * out, int lda, int ldb, int ldc);

 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
csrc/ops.cu
@@ -703,17 +703,17 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {

-  int num_blocks = (m+7)/8;
+  int num_blocks = (m+31)/32;

-  cout << num_blocks << endl;
-  cout << lda << endl;
-  cout << ldb << endl;
-  cout << ldc << endl;
+  // cout << num_blocks << endl;
+  // cout << lda << endl;
+  // cout << ldb << endl;
+  // cout << ldc << endl;

-  cout << m << endl;
-  cout << n << endl;
-  cout << k << endl;
-  kgemm_4bit_inference<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  // cout << m << endl;
+  // cout << n << endl;
+  // cout << k << endl;
+  kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
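The host wrapper now assigns 32 output columns per block instead of 8, so the grid shrinks to (m+31)/32 blocks, and the launch uses 160 threads: with WARPS = 5 in kernels.cu, four 32-thread warps feed shared memory (the 128-wide loads) and the last warp runs the wmma accumulation. The launch arithmetic, written out as a small Python sketch for clarity:

    # Launch arithmetic implied by the new gemm_4bit_inference host code;
    # illustrative only.
    WARPS = 5
    threads_per_block = WARPS * 32           # 160, the value passed to <<< >>> above
    loader_threads = threads_per_block - 32  # 128: four warps stage data into shared memory

    def grid_size(m, cols_per_block=32):
        # was (m + 7) / 8 before this commit; one block now covers 32 output columns
        return (m + cols_per_block - 1) // cols_per_block

    for m in (8, 33, 4096, 16384):
        print(m, grid_size(m))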
tests/test_functional.py
@@ -2358,20 +2358,19 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    debug = True
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
-    #for dim in [4096]:
+    for dim in [4096]:
     #for dim in [128+1]:
         errs = []
         relerrs = []
         max_err = 0
         max_relerr = 0
         for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
             #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim, dtype=dtype, device='cuda')
             B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
             #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

             #print('')
             #print(A)

@@ -2397,7 +2396,7 @@ def test_cutlass3_gemm(dtype):
             errs.append(err)
             relerrs.append(relerr)
-            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
             #  print('')
             #  print(i, err, relerr)
             #  print(A.flatten()[-6:])

@@ -2412,7 +2411,7 @@ def test_cutlass3_gemm(dtype):
         c = int(C1.numel()*0.0014*(dim/256))+1

-        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
         #print(c/math.sqrt(dim))
         print('')
         print(dim, sum(errs)/len(errs)/math.sqrt(dim))

@@ -2422,29 +2421,73 @@ def test_cutlass3_gemm(dtype):
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_gemm_4bit(dtype):
-    for i in range(1):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #torch.random.manual_seed(17)
-        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-
-        qB, state = F.quantize_nf4(B)
-        F.dequantize_nf4(qB, state)
-
-        C3 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-        C1 = bnb.matmul_4bit(A, qB.t(), state)
-        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-        print(C1.shape, C2.shape)
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4096, 5120, 6656, 8192]:
+    #for dim in [32]:
+    for dim in [4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(1):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-1] = 0
+            #B[:, :-1] = 0
+            #print('')
+            #print(A)
+            #print(B)
+
+            qB, state = F.quantize_nf4(B)
+            F.dequantize_nf4(qB, state)
+
+            C1 = torch.matmul(A, B.t())
+            #C1 = bnb.matmul_4bit(A, qB.t(), state)
+            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+            #print(C1)
+            #print(C2)
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+                print('')
+                print(i, err, relerr)
+                print(A.flatten()[-6:])
+                print(B.flatten()[-6:])
+                out = A.flatten()[-6:]*B.flatten()[-6:]
+                print(out)
+                print(out[:-1].sum())
+                print('='*80)
+                print(C1.flatten()[-6:])
+                print(C2.flatten()[-6:])
+                #assert False, 'ERROR'
+
+        #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)
+        c = int(C1.numel()*0.0014*(dim/256))+1
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
+        #print(c/math.sqrt(dim))
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))

 def test_pipeline_func():
     a = torch.rand(2, 4).cuda()
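The new `test_gemm_4bit` mirrors `test_cutlass3_gemm`: because tensor-core accumulation is non-deterministic, it compares against a `torch.matmul` reference through mean absolute and relative error and tolerates a size-dependent count of outliers rather than exact elementwise equality. A condensed sketch of that check; `approx_close_with_outliers` is a simplified stand-in for the test file's `assert_all_approx_close` helper:

    import torch

    def approx_close_with_outliers(C1, C2, atol=1e-5, rtol=0.01, count=1):
        """Simplified stand-in for assert_all_approx_close: allow up to `count`
        elements to violate the tolerance before failing."""
        mismatched = (~torch.isclose(C1, C2, atol=atol, rtol=rtol)).sum().item()
        assert mismatched <= count, f"{mismatched} outliers exceed allowance of {count}"
        return mismatched

    def compare(C1, C2, dim):
        err = torch.abs(C1 - C2)
        relerr = err / (torch.abs(C1) + 1e-8)
        # outlier allowance scales with matrix size, as in the test above
        c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
        approx_close_with_outliers(C1, C2, 1e-5, 0.01, count=c)
        return err.mean().item(), relerr.mean().item()

    # toy usage with random stand-ins for the reference and kernel outputs
    dim = 256
    C1 = torch.randn(1, 4 * dim)
    C2 = C1 + 1e-6 * torch.randn_like(C1)
    print(compare(C1, C2, dim))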