Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
6e2544da
"...text-generation-inference.git" did not exist on "bfddfa5955dd6558814d313e4364ddf534848632"
Commit
6e2544da
authored
Apr 25, 2023
by
Tim Dettmers
Browse files
Added cutlass example.
parent
6bfd7a40
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
191 additions
and
0 deletions
+191
-0
csrc/kernels.cu
csrc/kernels.cu
+134
-0
csrc/ops.cu
csrc/ops.cu
+57
-0
No files found.
csrc/kernels.cu
View file @
6e2544da
...
@@ -2942,6 +2942,140 @@ template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global
...
@@ -2942,6 +2942,140 @@ template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global
// 9. write outputs to matmul output matrix
// 9. write outputs to matmul output matrix
}
}
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"
// CuTe tutorial-style GEMM kernel: C = alpha * A * B^T + beta * C, where
// A is (M,K), B is (N,K), and C is (M,N).
//
// Launch expectations (established by the host wrapper that instantiates it):
//   - one thread block per (BLK_M, BLK_N) tile of C, indexed by
//     (blockIdx.x, blockIdx.y);
//   - blockDim.x == size(CThreadLayout{}) threads, enforced via
//     __launch_bounds__ below;
//   - static shared memory of cosize(ABlockLayout) TA elements plus
//     cosize(BBlockLayout) TB elements.
//
// The third C-layout parameter (CBlockLayout) is accepted for interface
// symmetry but deliberately unnamed: only its static-ness is checked.
template <class MShape, class NShape, class KShape,
          class TA, class AStride, class ABlockLayout, class AThreadLayout,
          class TB, class BStride, class BBlockLayout, class BThreadLayout,
          class TC, class CStride, class CBlockLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
            TC      * C, CStride dC, CBlockLayout  , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;
  using X = Underscore;

  // Preconditions: all tile and thread layouts must be compile-time static
  // so shared-memory extents and the launch bound are known at compile time.
  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);

  // Every thread participates in all three partitionings, so the copy
  // layouts must cover exactly as many threads as the compute layout.
  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));

  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));   // BLK_K

  // Shared memory staging buffers for one (BLK_M,BLK_K) tile of A and one
  // (BLK_N,BLK_K) tile of B.
  __shared__ TA smemA[cosize_v<ABlockLayout>];
  __shared__ TB smemB[cosize_v<BBlockLayout>];
  auto sA = make_tensor(make_smem_ptr(smemA), blockA);            // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smemB), blockB);            // (BLK_N,BLK_K)

  // Views over the full global-memory operands.
  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), dA);  // (M,K)
  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), dB);  // (N,K)
  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), dC);  // (M,N)

  // Slice out this thread block's tiles. The k-mode is left open ( _ ) so
  // gA/gB carry the whole k-tile sequence this block iterates over.
  auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB)); // (BLK_M,BLK_N,BLK_K)
  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _);            // (m,n,k)

  auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1,  X, _1>{}); // (BLK_M,BLK_K,k)
  auto gB = local_tile(mB, blk_shape, blk_coord, Step< X, _1, _1>{}); // (BLK_N,BLK_K,k)
  auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1, _1,  X>{}); // (BLK_M,BLK_N)

  //
  // Partition the copying of A and B tiles across the threads
  //
  // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
  //   Default is a raked partition, but can be changed with Step<X,Y> parameter
  auto tAgA = local_partition(gA, tA, threadIdx.x);                 // (THR_M,THR_K,k)
  auto tAsA = local_partition(sA, tA, threadIdx.x);                 // (THR_M,THR_K)

  auto tBgB = local_partition(gB, tB, threadIdx.x);                 // (THR_N,THR_K,k)
  auto tBsB = local_partition(sB, tB, threadIdx.x);                 // (THR_N,THR_K)

  //
  // Define C accumulators and A/B partitioning
  //
  // TUTORIAL: Example of partitioning via projections of tC

  // Partition sA (M,K) by the rows of tC.
  auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1,  X>{}); // (THR_M,BLK_K)
  // Partition sB (N,K) by the cols of tC.
  auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X, _1>{}); // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC.
  auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1, _1>{}); // (THR_M,THR_N)

  // Register-resident accumulators, shaped like this thread's C tile.
  auto tCrC = make_fragment_like(tCgC);                             // (THR_M,THR_N)
  clear(tCrC);

#if 1
  // TUTORIAL: Example of a very simple compute loop
  //   Data is read from global to shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared memory directly via the tC partitioning
  auto k_max = size<2>(tAgA);
  for (int k = 0; k < k_max; ++k)
  {
    // Stage the k-th global tiles into shared memory.
    copy(tAgA(_, _, k), tAsA);
    copy(tBgB(_, _, k), tBsB);

    // In case copy uses cp.async, make sure that the cp.async
    // instructions are ordered with respect to other cp.async
    // instructions (fence), then wait on all the outstanding copy
    // operations (wait<0>()).  __syncthreads() alone does not do this.
    //
    // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
    // This is equivalent to cp.async.commit_group followed by
    // cp.async_wait_group 0.  This should make the first
    // cp_async_fence() (which also issues cp.async.commit_group)
    // redundant.  The tutorial works as-is, so we'll leave the
    // redundant fence in for now and study its removal later.
    cp_async_fence();
    cp_async_wait<0>();

    __syncthreads();

    // Accumulate this k-tile's contribution from shared memory.
    gemm(tCsA, tCsB, tCrC);

    // Barrier before the next iteration overwrites sA/sB.
    __syncthreads();
  }
#endif

  // Epilogue: C = alpha * accum + beta * C for this thread's elements.
  axpby(alpha, tCrC, beta, tCgC);
}
//==============================================================
//==============================================================
// TEMPLATE DEFINITIONS
// TEMPLATE DEFINITIONS
...
...
csrc/ops.cu
View file @
6e2544da
...
@@ -665,6 +665,63 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
...
@@ -665,6 +665,63 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
}
}
#include <cstdio>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>
// Host-side launcher for the CuTe GEMM kernel: C = alpha * A * B^T + beta * C.
//
// Parameters:
//   m, n, k        — problem dimensions: A is (m,k), B is (n,k), C is (m,n),
//                    all column-major with leading dimensions ldA/ldB/ldC.
//   alpha, beta    — scaling factors for the product and the existing C.
//   A, B           — read-only device pointers to the operands.
//   C              — device pointer updated in place.
//   stream         — CUDA stream for the (asynchronous) launch; defaults to
//                    the legacy default stream.
//
// Tile configuration is fixed here: 128x128x8 blocks, 32x8 copy-thread
// layouts and a 16x16 compute-thread layout (256 threads per block).
//
// Fix over the original: the kernel launch was fire-and-forget — a bad
// launch configuration (e.g. no device, excessive shared memory) would fail
// silently.  We now surface launch errors via cudaGetLastError(), since a
// <<<...>>> launch does not itself return an error code.
template <typename TA, typename TB, typename TC,
          typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);

  // Define strides (mixed): unit stride in the leading mode, runtime
  // leading dimension in the second mode.
  auto dA = make_stride(Int<1>{}, ldA);
  auto dB = make_stride(Int<1>{}, ldB);
  auto dC = make_stride(Int<1>{}, ldC);

  // Define block sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};

  // Define the block layouts (static)
  auto sA = make_layout(make_shape(bM, bK));   // (BLK_M,BLK_K) smem tile of A
  auto sB = make_layout(make_shape(bN, bK));   // (BLK_N,BLK_K) smem tile of B
  auto sC = make_layout(make_shape(bM, bN));   // (BLK_M,BLK_N) tile of C

  // Define the thread layouts (static); copy layouts use 256 threads as
  // 32x8, the compute layout uses the same 256 threads as 16x16.
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

  // One block per (BLK_M,BLK_N) tile of C; ceil-div covers ragged edges.
  dim3 dimBlock(size(tC));
  dim3 dimGrid(ceil_div(size(M), size(bM)),
               ceil_div(size(N), size(bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (M,  N,  K,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);

  // Kernel launches return no status; check for launch-configuration
  // errors explicitly.  Execution errors still surface at the next sync.
  cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess)
    fprintf(stderr, "gemm: kernel launch failed: %s\n",
            cudaGetErrorString(launch_err));
}
//==============================================================
//==============================================================
// TEMPLATE DEFINITIONS
// TEMPLATE DEFINITIONS
//==============================================================
//==============================================================
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment