Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
84964db9
Commit
84964db9
authored
Apr 25, 2023
by
Tim Dettmers
Browse files
CUTLASS compiles.
parent
6e2544da
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
14 deletions
+20
-14
Makefile
Makefile
+4
-3
bitsandbytes/functional.py
bitsandbytes/functional.py
+2
-2
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+1
-0
csrc/kernels.cu
csrc/kernels.cu
+12
-6
csrc/ops.cu
csrc/ops.cu
+1
-3
No files found.
Makefile
View file @
84964db9
MKFILE_PATH
:=
$(
abspath
$(
lastword
$(MAKEFILE_LIST)
))
MKFILE_PATH
:=
$(
abspath
$(
lastword
$(MAKEFILE_LIST)
))
ROOT_DIR
:=
$(
patsubst
%/,%,
$(
dir
$(MKFILE_PATH)
))
ROOT_DIR
:=
$(
patsubst
%/,%,
$(
dir
$(MKFILE_PATH)
))
GPP
:=
/usr/bin/g++
#GPP:= /usr/bin/g++
GPP
:=
/sw/gcc/11.2.0/bin/g++
ifeq
($(CUDA_HOME),)
ifeq
($(CUDA_HOME),)
CUDA_HOME
:=
$(
shell
which nvcc | rev |
cut
-d
'/'
-f3-
| rev
)
CUDA_HOME
:=
$(
shell
which nvcc | rev |
cut
-d
'/'
-f3-
| rev
)
endif
endif
...
@@ -25,7 +26,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
...
@@ -25,7 +26,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE
:=
-I
$(CUDA_HOME)
/include
-I
$(ROOT_DIR)
/csrc
-I
$(CONDA_PREFIX)
/include
-I
$(ROOT_DIR)
/include
INCLUDE
:=
-I
$(CUDA_HOME)
/include
-I
$(ROOT_DIR)
/csrc
-I
$(CONDA_PREFIX)
/include
-I
$(ROOT_DIR)
/include
INCLUDE_10x
:=
-I
$(CUDA_HOME)
/include
-I
$(ROOT_DIR)
/csrc
-I
$(ROOT_DIR)
/dependencies/cub
-I
$(ROOT_DIR)
/include
INCLUDE_10x
:=
-I
$(CUDA_HOME)
/include
-I
$(ROOT_DIR)
/csrc
-I
$(ROOT_DIR)
/dependencies/cub
-I
$(ROOT_DIR)
/include
INCLUDE_cutlass
:=
-I
$(ROOT_DIR)
/dependencies/cutlass/include
INCLUDE_cutlass
:=
-I
$(ROOT_DIR)
/dependencies/cutlass/include
-I
$(ROOT_DIR)
/dependencies/cutlass/tools/util/include/
-I
$(ROOT_DIR)
/dependencies/cutlass/include/cute/util/
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcublas
-lcublasLt
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcublas
-lcublasLt
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
# NVIDIA NVCC compilation flags
# NVIDIA NVCC compilation flags
...
@@ -104,7 +105,7 @@ cuda11x: $(BUILD_DIR) env
...
@@ -104,7 +105,7 @@ cuda11x: $(BUILD_DIR) env
cuda11x_cutlass
:
$(BUILD_DIR) env cutlass
cuda11x_cutlass
:
$(BUILD_DIR) env cutlass
$(NVCC)
$(CC_cublasLt111)
-Xcompiler
'-fPIC'
--use_fast_math
-Xptxas
=
-v
-dc
$(FILES_CUDA)
$(INCLUDE)
$(INCLUDE_cutlass)
$(LIB)
--output-directory
$(BUILD_DIR)
$(NVCC)
$(CC_cublasLt111)
-Xcompiler
'-fPIC'
--use_fast_math
-Xptxas
=
-v
-dc
$(FILES_CUDA)
$(INCLUDE)
$(INCLUDE_cutlass)
$(LIB)
--output-directory
$(BUILD_DIR)
$(NVCC)
$(CC_cublasLt111)
-Xcompiler
'-fPIC'
-dlink
$(BUILD_DIR)
/ops.o
$(BUILD_DIR)
/kernels.o
-o
$(BUILD_DIR)
/link.o
$(NVCC)
$(CC_cublasLt111)
-Xcompiler
'-fPIC'
-dlink
$(BUILD_DIR)
/ops.o
$(BUILD_DIR)
/kernels.o
-o
$(BUILD_DIR)
/link.o
$(GPP)
-std
=
c++
20
-DBUILD_CUDA
-shared
-fPIC
$(INCLUDE)
$(BUILD_DIR)
/ops.o
$(BUILD_DIR)
/kernels.o
$(BUILD_DIR)
/link.o
$(FILES_CPP)
-o
./bitsandbytes/libbitsandbytes_cuda
$(CUDA_VERSION)
.so
$(LIB)
$(GPP)
-std
=
c++
17
-DBUILD_CUDA
-shared
-fPIC
$(INCLUDE)
$(BUILD_DIR)
/ops.o
$(BUILD_DIR)
/kernels.o
$(BUILD_DIR)
/link.o
$(FILES_CPP)
-o
./bitsandbytes/libbitsandbytes_cuda
$(CUDA_VERSION)
.so
$(LIB)
cuda12x
:
$(BUILD_DIR) env
cuda12x
:
$(BUILD_DIR) env
$(NVCC)
$(CC_cublasLt111)
$(CC_ADA_HOPPER)
-Xcompiler
'-fPIC'
--use_fast_math
-Xptxas
=
-v
-dc
$(FILES_CUDA)
$(INCLUDE)
$(LIB)
--output-directory
$(BUILD_DIR)
$(NVCC)
$(CC_cublasLt111)
$(CC_ADA_HOPPER)
-Xcompiler
'-fPIC'
--use_fast_math
-Xptxas
=
-v
-dc
$(FILES_CUDA)
$(INCLUDE)
$(LIB)
--output-directory
$(BUILD_DIR)
...
...
bitsandbytes/functional.py
View file @
84964db9
...
@@ -176,7 +176,7 @@ def create_custom_map(seed=0, scale=0.01):
...
@@ -176,7 +176,7 @@ def create_custom_map(seed=0, scale=0.01):
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
#v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
#v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
#v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
#v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
v
=
[
1.6072478919002173
,
1.1864907014855421
,
0.9099343314196248
,
0.6898544638558411
,
0.4990924080314459
,
0.32505049268156666
,
0.16039309503073892
]
# 0.946 24.37 22.88
#
v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88
# 7B evo start
# 7B evo start
#v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
#v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
...
@@ -186,7 +186,7 @@ def create_custom_map(seed=0, scale=0.01):
...
@@ -186,7 +186,7 @@ def create_custom_map(seed=0, scale=0.01):
# 13B evo start
# 13B evo start
#v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
#v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
#v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
#v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
#
v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
v
=
[
1.5842247437829478
,
1.2037228884260156
,
0.900369059187269
,
0.6898587137788914
,
0.4949097822874533
,
0.2959061887131868
,
0.15712393618216908
]
# mean evo 7B + 13B
# mean evo 7B + 13B
#v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
#v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
...
...
bitsandbytes/nn/modules.py
View file @
84964db9
...
@@ -228,6 +228,7 @@ class LinearNF4(Linear4bit):
...
@@ -228,6 +228,7 @@ class LinearNF4(Linear4bit):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
)
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
)
class
Int8Params
(
torch
.
nn
.
Parameter
):
class
Int8Params
(
torch
.
nn
.
Parameter
):
def
__new__
(
def
__new__
(
cls
,
cls
,
...
...
csrc/kernels.cu
View file @
84964db9
...
@@ -12,6 +12,14 @@
...
@@ -12,6 +12,14 @@
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/cublas_wrappers.hpp"
#include "cutlass/util/helper_cuda.hpp"
#define HLF_MAX 65504
#define HLF_MAX 65504
#define TH 1024
#define TH 1024
...
@@ -2709,7 +2717,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
...
@@ -2709,7 +2717,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
}
}
}
}
#define
C
1.0f/127.0f
#define
DENORM
1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#define SMEM_SIZE 8*256
template
<
typename
T
,
int
SPMM_ITEMS
,
int
BITS
>
template
<
typename
T
,
int
SPMM_ITEMS
,
int
BITS
>
...
@@ -2813,7 +2821,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
...
@@ -2813,7 +2821,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
float
valB
=
local_valsB
[
k
];
float
valB
=
local_valsB
[
k
];
float
valA
=
local_valA
[
i
];
float
valA
=
local_valA
[
i
];
if
(
valB
!=
0.0
&&
valA
!=
0.0
)
if
(
valB
!=
0.0
&&
valA
!=
0.0
)
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
((
float
)
smem_dequant_stats
[
idx
+
k
-
local_idx_col_B_offset
])
*
C
*
valB
*
valA
;
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
((
float
)
smem_dequant_stats
[
idx
+
k
-
local_idx_col_B_offset
])
*
DENORM
*
valB
*
valA
;
}
}
else
else
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
(
float
)
local_valsB
[
k
]
*
(
float
)
local_valA
[
i
];
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
(
float
)
local_valsB
[
k
]
*
(
float
)
local_valA
[
i
];
...
@@ -2960,7 +2968,7 @@ void
...
@@ -2960,7 +2968,7 @@ void
gemm_device
(
MShape
M
,
NShape
N
,
KShape
K
,
gemm_device
(
MShape
M
,
NShape
N
,
KShape
K
,
TA
const
*
A
,
AStride
dA
,
ABlockLayout
blockA
,
AThreadLayout
tA
,
TA
const
*
A
,
AStride
dA
,
ABlockLayout
blockA
,
AThreadLayout
tA
,
TB
const
*
B
,
BStride
dB
,
BBlockLayout
blockB
,
BThreadLayout
tB
,
TB
const
*
B
,
BStride
dB
,
BBlockLayout
blockB
,
BThreadLayout
tB
,
TC
*
C
,
CStride
dC
,
CBlockLayout
,
CThreadLayout
tC
,
TC
*
out
,
CStride
dC
,
CBlockLayout
,
CThreadLayout
tC
,
Alpha
alpha
,
Beta
beta
)
Alpha
alpha
,
Beta
beta
)
{
{
using
namespace
cute
;
using
namespace
cute
;
...
@@ -2991,7 +2999,7 @@ gemm_device(MShape M, NShape N, KShape K,
...
@@ -2991,7 +2999,7 @@ gemm_device(MShape M, NShape N, KShape K,
// Represent the full tensors
// Represent the full tensors
auto
mA
=
make_tensor
(
make_gmem_ptr
(
A
),
make_shape
(
M
,
K
),
dA
);
// (M,K)
auto
mA
=
make_tensor
(
make_gmem_ptr
(
A
),
make_shape
(
M
,
K
),
dA
);
// (M,K)
auto
mB
=
make_tensor
(
make_gmem_ptr
(
B
),
make_shape
(
N
,
K
),
dB
);
// (N,K)
auto
mB
=
make_tensor
(
make_gmem_ptr
(
B
),
make_shape
(
N
,
K
),
dB
);
// (N,K)
auto
mC
=
make_tensor
(
make_gmem_ptr
(
C
),
make_shape
(
M
,
N
),
dC
);
// (M,N)
auto
mC
=
make_tensor
(
make_gmem_ptr
(
out
),
make_shape
(
M
,
N
),
dC
);
// (M,N)
// Get the appropriate blocks for this thread block --
// Get the appropriate blocks for this thread block --
// potential for thread block locality
// potential for thread block locality
...
@@ -3034,7 +3042,6 @@ gemm_device(MShape M, NShape N, KShape K,
...
@@ -3034,7 +3042,6 @@ gemm_device(MShape M, NShape N, KShape K,
// Clear the accumulators
// Clear the accumulators
clear
(
tCrC
);
clear
(
tCrC
);
#if 1
// TUTORIAL: Example of a very simple compute loop
// TUTORIAL: Example of a very simple compute loop
// Data is read from global to shared memory via the tA|tB partitioning
// Data is read from global to shared memory via the tA|tB partitioning
...
@@ -3071,7 +3078,6 @@ gemm_device(MShape M, NShape N, KShape K,
...
@@ -3071,7 +3078,6 @@ gemm_device(MShape M, NShape N, KShape K,
__syncthreads
();
__syncthreads
();
}
}
#endif
axpby
(
alpha
,
tCrC
,
beta
,
tCgC
);
axpby
(
alpha
,
tCrC
,
beta
,
tCgC
);
}
}
...
...
csrc/ops.cu
View file @
84964db9
...
@@ -666,11 +666,9 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
...
@@ -666,11 +666,9 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include <cute/tensor.hpp>
template
<
typename
TA
,
typename
TB
,
typename
TC
,
template
<
typename
TA
,
typename
TB
,
typename
TC
,
typename
Alpha
,
typename
Beta
>
typename
Alpha
,
typename
Beta
>
void
void
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment