OpenDAS / bitsandbytes / Commits / 0afc8e9e

Commit 0afc8e9e, authored Apr 26, 2023 by Tim Dettmers

Best attempt at cutlass3.

parent 84964db9

Showing 8 changed files with 282 additions and 102 deletions (+282 -102)
Makefile                     +4   -4
bitsandbytes/functional.py  +98   -0
csrc/kernels.cu             +77  -51
csrc/kernels.cuh            +21   -1
csrc/ops.cu                 +31  -46
csrc/ops.cuh                +12   -0
csrc/pythonInterface.c      +18   -0
tests/test_functional.py    +21   -0
Makefile

@@ -55,8 +55,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90

@@ -103,9 +103,9 @@ cuda11x: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda11x_cutlass: $(BUILD_DIR) env cutlass
-	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
+	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda12x: $(BUILD_DIR) env
 	$(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
bitsandbytes/functional.py

@@ -1374,6 +1374,104 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
     return sout

+
+def cutlass3_gemm(
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
+):
+    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32)
+    if out is None:
+        out = torch.zeros(size=sout, dtype=torch.float32, device=A.device)
+
+    sA = A.shape
+    sB = B.shape
+    if transposed_A and len(sA) == 2:
+        sA = (sA[1], sA[0])
+    elif transposed_A and len(sA) == 3:
+        sA = (sA[0], sA[2], sA[0])
+    if transposed_B and len(sB) == 2:
+        sB = (sB[1], sB[0])
+    elif transposed_B and len(sB) == 3:
+        sB = (sB[0], sB[2], sB[0])
+    # This is a mess: cuBLAS expects column-major, but PyTorch is row-major.
+    # To perform the matrix multiplication, we therefore treat A, B, and C as
+    # column-major matrices (the transpose of row-major is column-major).
+    # This means we compute B^T A^T = C^T and explicitly switch the dimensions
+    # of each of these matrices in the input arguments for cuBLAS.
+    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
+    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
+    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
+    if len(sB) == 2:
+        if B.stride()[0] == B.shape[1]:
+            transposed_B = False
+        elif B.stride()[1] == B.shape[0]:
+            transposed_B = True
+        if len(A.shape) == 2:
+            if A.stride()[0] == A.shape[1]:
+                transposed_A = False
+            elif A.stride()[1] == A.shape[0]:
+                transposed_A = True
+        else:
+            if A.stride()[1] == A.shape[2]:
+                transposed_A = False
+            elif A.stride()[2] == A.shape[1]:
+                transposed_A = True
+
+        if len(sA) == 2:
+            n = sA[0]
+            ldb = A.stride()[1 if transposed_A else 0]
+        elif len(sA) == 3 and len(sB) == 2:
+            n = sA[0] * sA[1]
+            ldb = sA[2]
+
+        m = sB[1]
+        k = sB[0]
+        lda = B.stride()[(1 if transposed_B else 0)]
+        ldc = sB[1]
+    elif len(sB) == 3:
+        # special case
+        assert len(sA) == 3
+        if not (sA[0] == sB[0] and sA[1] == sB[1]):
+            raise ValueError(
+                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
+            )
+
+        transposed_A = True
+        transposed_B = False
+
+        m = sB[2]
+        n = sA[2]
+        k = sB[0] * sB[1]
+
+        lda = m
+        ldb = sA[2]
+        ldc = m
+
+    ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+    # B^T @ A^T = C^T
+    # [km, nk -> mn]
+    lda = ldb = ldc = 1
+    #lda = 1
+    print(m, n, k, lda, ldb, ldc)
+
+    is_on_gpu([B, A, out])
+    m = ct.c_int32(m)
+    n = ct.c_int32(n)
+    k = ct.c_int32(k)
+    lda = ct.c_int32(lda)
+    ldb = ct.c_int32(ldb)
+    ldc = ct.c_int32(ldc)
+    alpha = ct.c_float(1.0)
+    beta = ct.c_float(0.0)
+    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc)
+
+    return out
+
+
 def igemm(
     A: Tensor,
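The layout comment inside cutlass3_gemm is the crux of this wrapper: the kernel consumes column-major buffers while PyTorch tensors are row-major. A minimal NumPy sketch of the identity it relies on (illustrative only, not part of the commit):

    import numpy as np

    m, k, n = 2, 3, 4
    A = np.arange(m * k, dtype=np.float32).reshape(m, k)  # row-major [m, k]
    B = np.arange(k * n, dtype=np.float32).reshape(k, n)  # row-major [k, n]

    # Reinterpreting A's row-major buffer as a column-major [k, m] matrix yields A^T.
    A_cm = A.ravel().reshape((k, m), order="F")
    B_cm = B.ravel().reshape((n, k), order="F")
    assert np.array_equal(A_cm, A.T) and np.array_equal(B_cm, B.T)

    # The column-major product B^T @ A^T is C^T, whose buffer read row-major is C.
    C_cm = B_cm @ A_cm  # shape [n, m]
    assert np.allclose(C_cm.T, A @ B)

This is why the wrapper passes get_ptr(B) as the first operand and get_ptr(A) as the second in the ccutlass_gemm call above.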
csrc/kernels.cu

@@ -19,7 +19,6 @@
 #include "cutlass/util/print_error.hpp"
 #include "cutlass/util/GPU_Clock.hpp"
 #include "cutlass/util/cublas_wrappers.hpp"
-#include "cutlass/util/helper_cuda.hpp"
 #define HLF_MAX 65504
 #define TH 1024

@@ -2928,73 +2927,84 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 }

-template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
-{
-  // element-wise kernel
-  // 1. Load batch x k into registers
-  // 2. Load k x k into registers
-  // 3. dequantize and store in second pair of k x k
-  // 4. matmul
-  // 5. sum with cub
-  // 6. store outputs
-  // TC kernel
-  // use k warps per thread block
-  // 1. threadblock use read-only cache to read in register tile for A into shared memory
-  // 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
-  // 3. each warp reads a segment of values 16x32 from B
-  // 4. do dequantization from register of B into second pair of registers
-  // 5. store (4) into fragment
-  // 6. matmul aggregate into fragment C
-  // 7. aggregate files of C into shared memory block C
-  // 8. sum (7)
-  // 9. write outputs to matmul output matrix
-}
+//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+//  // element-wise kernel
+//  // 1. Load batch x k into registers
+//  // 2. Load k x k into registers
+//  // 3. dequantize and store in second pair of k x k
+//  // 4. matmul
+//  // 5. sum with cub
+//  // 6. store outputs
+//  // TC kernel
+//  // use k warps per thread block
+//  // 1. threadblock use read-only cache to read in register tile for A into shared memory
+//  // 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
+//  // 3. each warp reads a segment of values 16x32 from B
+//  // 4. do dequantization from register of B into second pair of registers
+//  // 5. store (4) into fragment
+//  // 6. matmul aggregate into fragment C
+//  // 7. aggregate files of C into shared memory block C
+//  // 8. sum (7)
+//  // 9. write outputs to matmul output matrix
+//}

 #include "cutlass/util/print_error.hpp"
 #include "cutlass/util/GPU_Clock.hpp"
 #if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
 #  include "cutlass/util/cublas_wrappers.hpp"
 #endif
-#include "cutlass/util/helper_cuda.hpp"
+//#include "cutlass/util/helper_cuda.hpp"

-template <class MShape, class NShape, class KShape,
-          class TA, class AStride, class ABlockLayout, class AThreadLayout,
-          class TB, class BStride, class BBlockLayout, class BThreadLayout,
-          class TC, class CStride, class CBlockLayout, class CThreadLayout,
-          class Alpha, class Beta>
-__global__ static
-__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-void
-gemm_device(MShape M, NShape N, KShape K,
-            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
-            Alpha alpha, Beta beta)
+__global__ void gemm_device(int M, int N, int K,
+                            float const* A,
+                            float const* B,
+                            float * out, int lda, int ldb, int ldc,
+                            float alpha, float beta)
 {
   using namespace cute;
   using X = Underscore;

   // Preconditions
-  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
-  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
-  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
+  //CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
+  //CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
+  //CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
+
+  // Define block sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int<  8>{};
+
+  // Define the block layouts (static)
+  auto bA = make_layout(make_shape(bM, bK));
+  auto bB = make_layout(make_shape(bN, bK));
+  auto bC = make_layout(make_shape(bM, bN));
+
+  // Define the thread layouts (static)
+  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

   //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
   //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
-  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
+  //CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K

   // Shared memory buffers
-  __shared__ TA smemA[cosize_v<ABlockLayout>];
-  __shared__ TB smemB[cosize_v<BBlockLayout>];
-  auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K)
-  auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K)
+  __shared__ float smemA[128*8];
+  __shared__ float smemB[128*8];
+  auto sA = make_tensor(make_smem_ptr(smemA), bA); // (BLK_M,BLK_K)
+  auto sB = make_tensor(make_smem_ptr(smemB), bB); // (BLK_N,BLK_K)
+
+  auto dA = make_stride(Int<1>{}, lda);
+  auto dB = make_stride(Int<1>{}, ldb);
+  auto dC = make_stride(Int<1>{}, ldc);

   // Represent the full tensors
   auto mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), dA); // (M,K)
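The rewritten gemm_device hard-codes its tiling instead of taking it from template parameters. A quick check of what those constants imply (my arithmetic, not output of the commit):

    # Tile bookkeeping implied by bM = bN = 128, bK = 8, and the 16 x 16 thread layout tC.
    bM, bN, bK = 128, 128, 8
    threads_per_block = 16 * 16                   # size(tC) = 256 threads
    smem_bytes_per_operand = bM * bK * 4          # 128*8 floats = 4 KiB each for smemA, smemB
    c_elems_per_thread = (bM // 16) * (bN // 16)  # each thread owns an 8 x 8 slice of C
    print(threads_per_block, smem_bytes_per_operand, c_elems_per_thread)  # 256 4096 64

This matches the launch configuration added to gemm_host in csrc/ops.cu below: dim3 dimBlock(16, 16) and one 128 x 128 output tile per block.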
@@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K,
...
@@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K,
}
}
//==============================================================
//==============================================================
// TEMPLATE DEFINITIONS
// TEMPLATE DEFINITIONS
//==============================================================
//==============================================================
template
__global__
void
kMatmul_inference_4bit
<
NF4
,
half
,
half
,
half
>(
half
*
A
,
unsigned
char
*
B
,
half
*
out
,
int
lda
,
int
ldb
,
int
rowsA
,
int
colsA
,
int
colsB
);
//template <class MShape, class NShape, class KShape,
// class TA, class AStride, class ABlockLayout, class AThreadLayout,
// class TB, class BStride, class BBlockLayout, class BThreadLayout,
// class TC, class CStride, class CBlockLayout, class CThreadLayout,
// class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
// half alpha, half beta);
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
...
...
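The kMatmul_inference_4bit stub above is only a commented-out plan, but the flow it describes — load 4-bit codes for B, dequantize into a second set of registers, then matmul — is easy to state at a high level. A hedged PyTorch sketch (mine, not the commit's; codebook stands in for the NF4 quantile table):

    import torch

    codebook = torch.linspace(-1.0, 1.0, 16)  # stand-in for the 16-entry NF4 table
    B_idx = torch.randint(0, 16, (8, 4))      # 4-bit codes, one per weight of B
    B_deq = codebook[B_idx]                   # dequantize codes into a float tile
    A = torch.randn(2, 8)
    out = A @ B_deq                           # matmul against the dequantized tile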
csrc/kernels.cuh

@@ -9,7 +9,7 @@
 #ifndef kernels
 #define kernels

-template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE> __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
+//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);

 template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);

@@ -122,4 +122,24 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

+//template <class MShape, class NShape, class KShape,
+//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
+//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
+//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
+//          class Alpha, class Beta>
+//__global__ static
+//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
+//void
+//gemm_device(MShape M, NShape N, KShape K,
+//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
+//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
+//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
+//            Alpha alpha, Beta beta);
+
+__global__ void gemm_device(int M, int N, int K,
+                            float const* A,
+                            float const* B,
+                            float * out, int lda, int ldb, int ldc,
+                            float alpha, float beta);
+
 #endif
csrc/ops.cu

@@ -91,14 +91,12 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
 }

-void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
-{
-  int num_blocks = (colsB+32-1)/32;
-  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
-  CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
+//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+//  int num_blocks = (colsB+32-1)/32;
+//  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
+//  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+//}

-template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE> __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB);

 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,

@@ -666,60 +664,47 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 #include <cute/tensor.hpp>
+#include "cutlass/util/helper_cuda.hpp"

-template <typename TA, typename TB, typename TC,
-          typename Alpha, typename Beta>
-void
-gemm(int m, int n, int k,
-     Alpha alpha,
-     TA const* A, int ldA,
-     TB const* B, int ldB,
-     Beta beta,
-     TC      * C, int ldC,
-     cudaStream_t stream = 0)
+void gemm_host(int m, int n, int k,
+               float alpha,
+               float const* A, int lda,
+               float const* B, int ldb,
+               float beta,
+               float * C, int ldc)
 {
+  cute::device_init(0);
   using namespace cute;

   // Define shapes (dynamic)
   auto M = int(m);
   auto N = int(n);
   auto K = int(k);

-  // Define strides (mixed)
-  auto dA = make_stride(Int<1>{}, ldA);
-  auto dB = make_stride(Int<1>{}, ldB);
-  auto dC = make_stride(Int<1>{}, ldC);
-
-  // Define block sizes (static)
-  auto bM = Int<128>{};
-  auto bN = Int<128>{};
-  auto bK = Int<  8>{};
-
-  // Define the block layouts (static)
-  auto sA = make_layout(make_shape(bM, bK));
-  auto sB = make_layout(make_shape(bN, bK));
-  auto sC = make_layout(make_shape(bM, bN));
-
-  // Define the thread layouts (static)
-  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
-
-  dim3 dimBlock(size(tC));
-  dim3 dimGrid(ceil_div(size(M), size(bM)),
-               ceil_div(size(N), size(bN)));
+  printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc);
+
+  dim3 dimBlock(16, 16);
+  dim3 dimGrid((M+127)/128, (N+127)/128);
+//  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
+//-
+//- dim3 dimBlock(size(tC));
+//- dim3 dimGrid(ceil_div(size(M), size(bM)),
+//-              ceil_div(size(N), size(bN)));
   gemm_device
-      <<< dimGrid, dimBlock, 0, stream >>>
-      (M,  N,  K,
-       A, dA, sA, tA,
-       B, dB, sB, tB,
-       C, dC, sC, tC,
+      <<< dimGrid, dimBlock, 0, 0 >>>
+      (M,  N,  K,
+       A,
+       B,
+       C,
+       lda, ldb, ldc,
        alpha, beta);
 }

 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
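The new launch configuration in gemm_host rounds the grid up so partial tiles at the matrix edges are still covered: (M+127)/128 is integer ceil-division by the 128-wide block tile. In Python terms (a check of the arithmetic, not commit code):

    M, N, bM, bN = 300, 200, 128, 128
    grid_x = (M + bM - 1) // bM  # same as (M+127)/128 in the C source -> 3
    grid_y = (N + bN - 1) // bN  # -> 2
    assert grid_x * bM >= M and grid_y * bN >= N  # every row/column lands in some block

Note also that dim3 dimBlock(16, 16) yields the same 256 threads per block that the old __launch_bounds__(size(CThreadLayout{})) version declared.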
csrc/ops.cuh

@@ -20,6 +20,11 @@
 #include <vector>
 #include <functional>

+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+
 #define CUDA_CHECK_RETURN(value) {                      \
   cudaError_t _m_cudaStat = value;                      \
   if (_m_cudaStat != cudaSuccess) {                     \

@@ -185,4 +190,11 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);

+void gemm_host(int m, int n, int k,
+               float alpha,
+               float const* A, int ldA,
+               float const* B, int ldB,
+               float beta,
+               float * C, int ldC);
+
 #endif
csrc/pythonInterface.c

@@ -20,6 +20,16 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }

+void cppgemm(int m, int n, int k,
+             float alpha,
+             float const* A, int ldA,
+             float const* B, int ldB,
+             float beta,
+             float * C, int ldC)
+{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC); }
+
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \

@@ -306,6 +316,14 @@ extern "C"
 	void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
 	void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }

+	void ccutlass_gemm(int m, int n, int k,
+	                   float alpha,
+	                   float const* A, int ldA,
+	                   float const* B, int ldB,
+	                   float beta,
+	                   float * C, int ldC)
+	{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC); }
+
 #endif

 	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
 	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
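These wrappers complete the Python-to-CUDA chain: functional.py calls lib.ccutlass_gemm, the extern "C" shim forwards to cppgemm, which calls gemm_host. A hedged ctypes sketch of the call site (library name and device pointers are illustrative, not exact):

    import ctypes as ct

    # bitsandbytes selects the shared object per CUDA version; the name here is assumed.
    lib = ct.cdll.LoadLibrary("./bitsandbytes/libbitsandbytes_cuda118.so")

    m = n = k = ct.c_int32(2)
    lda = ldb = ldc = ct.c_int32(1)
    alpha, beta = ct.c_float(1.0), ct.c_float(0.0)
    # A_ptr, B_ptr, out_ptr would be ct.c_void_p device pointers (get_ptr(tensor)).
    # B comes first, matching the B^T A^T = C^T trick in cutlass3_gemm:
    # lib.ccutlass_gemm(m, n, k, alpha, B_ptr, lda, A_ptr, ldb, beta, out_ptr, ldc)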
tests/test_functional.py

@@ -2351,3 +2351,24 @@ def test_normal_map_tree():
         pivots.append((values[i-1]+values[i])/2)
     print(pivots)
+
+
+def test_cutlass3_gemm():
+    #A = torch.rand(2, 2).cuda()
+    #B = torch.rand(2, 2).cuda()
+    A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
+    B = torch.ones(2, 2).float().cuda()
+
+    print('')
+    print(A)
+    print(B)
+
+    C1 = torch.matmul(A, B)
+    print(C1)
+
+    C2 = F.cutlass3_gemm(A, B.t())
+    print(C2)
+
+    C2 = F.cutlass3_gemm(A, B)
+    print(C2)
+
+    C2 = F.cutlass3_gemm(B.t(), A.t().contiguous())
+    print(C2)
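For reference, the values the new test should print if the kernel is correct (hand-computed, not captured test output):

    import torch

    A = torch.arange(4).reshape(2, 2).float()  # [[0., 1.], [2., 3.]]
    B = torch.ones(2, 2)
    print(A @ B)        # tensor([[1., 1.], [5., 5.]]) -- target of the first two calls
    print((A @ B).t())  # tensor([[1., 5.], [1., 5.]]) -- target of the B.t() @ A.t() call

The first two cutlass3_gemm calls differ only in whether B is passed pre-transposed, so both target A @ B; the third computes B^T A^T, i.e. (A @ B)^T.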