gaoqiong / composable_kernel
Commit dc60d169, authored Jan 09, 2019 by Chao Liu

adding implicit gemm

parent 05971163
Showing 3 changed files with 416 additions and 0 deletions:

driver/device_implicit_gemm_convolution.cuh           +120  -0
src/include/gemm.cuh                                  +118  -0
src/include/gridwise_implicit_gemm_convolution.cuh    +178  -0
driver/device_implicit_gemm_convolution.cuh  (new file, mode 100644)
#pragma once
#include "gridwise_implicit_gemm_convolution.cuh"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution(
    InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out)
{
    std::size_t data_sz = sizeof(T);

    // allocate device buffers sized to each tensor's element space
    DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
    DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());

    int num_thread = std::thread::hardware_concurrency();

    // copy host tensors to device
    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    out_device_buf.ToDevice(out.mData.data());

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 1
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 16;
    constexpr unsigned NPerThread   = 2;
    constexpr unsigned KPerThread   = 4;
    constexpr unsigned CPerThread   = 2;
    constexpr unsigned BlockSize    = 128;
#elif 0
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 27;
    constexpr unsigned NPerThread   = 2;
    constexpr unsigned KPerThread   = 4;
    constexpr unsigned CPerThread   = 2;
    constexpr unsigned BlockSize    = 216;
#elif 0
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 32;
    constexpr unsigned NPerThread   = 2;
    constexpr unsigned KPerThread   = 4;
    constexpr unsigned CPerThread   = 2;
    constexpr unsigned BlockSize    = 256;
#endif

    // one block per output tile of shape
    // [NPerBlock, KPerBlock, OutTileSizeH*YPerBlock, OutTileSizeW*XPerBlock]
    constexpr unsigned GridSize =
        (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
        (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
        (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));

    dim3 block_dim(BlockSize);
    dim3 grid_dim(GridSize);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    cudaEvent_t start, stop;
    float elapsedTime;

    cudaEventCreate(&start);
    cudaEventRecord(start, 0);

    gridwise_implicit_gemm_convolution<T,
                                       InDesc,
                                       WeiDesc,
                                       OutDesc,
                                       OutTileSizeH,
                                       OutTileSizeW,
                                       NPerBlock,
                                       KPerBlock,
                                       CPerBlock,
                                       YPerBlock,
                                       XPerBlock,
                                       NPerThread,
                                       KPerThread,
                                       CPerThread,
                                       BlockSize,
                                       GridSize>
        <<<grid_dim, block_dim>>>(InDesc{},
                                  static_cast<T*>(in_device_buf.GetDeviceBuffer()),
                                  WeiDesc{},
                                  static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
                                  OutDesc{},
                                  static_cast<T*>(out_device_buf.GetDeviceBuffer()));

    cudaEventCreate(&stop);
    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);

    cudaEventElapsedTime(&elapsedTime, start, stop);
    printf("Elapsed time : %f ms\n", elapsedTime);

    checkCudaErrors(cudaGetLastError());

    out_device_buf.FromDevice(out.mData.data());
}
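As a sanity check on the launch configuration above: GridSize is simply the number of output tiles, one thread block per tile. A minimal host-side sketch of the same arithmetic, using the first tuning block and a hypothetical output shape [N=16, K=64, Ho=32, Wo=32] (the real code reads these lengths from OutDesc):

#include <cstdio>

int main()
{
    // hypothetical output tensor lengths [N, K, Ho, Wo]; assumptions, not from the commit
    const unsigned N = 16, K = 64, Ho = 32, Wo = 32;

    const unsigned NPerBlock = 2, KPerBlock = 32;
    const unsigned OutTileSizeH = 2, YPerBlock = 1;   // 2 output rows per block
    const unsigned OutTileSizeW = 2, XPerBlock = 16;  // 32 output columns per block

    // same tiling product as device_implicit_gemm_convolution
    const unsigned GridSize = (N / NPerBlock) * (K / KPerBlock) *
                              (Ho / (OutTileSizeH * YPerBlock)) *
                              (Wo / (OutTileSizeW * XPerBlock));

    printf("GridSize = %u\n", GridSize); // 8 * 2 * 16 * 1 = 256 blocks
    return 0;
}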
src/include/gemm.cuh  (new file, mode 100644)
#pragma once

template <class ThreadMatrixA,
          bool TransA,
          class FloatA,
          class ThreadMatrixB,
          bool TransB,
          class FloatB,
          class ThreadMatrixC,
          bool TransC,
          class FloatC,
          class Accumulator>
__device__ void threadwise_gemm(ThreadMatrixA,
                                Constant<bool, TransA>,
                                FloatA* const p_a_thread,
                                ThreadMatrixB,
                                Constant<bool, TransB>,
                                FloatB* const p_b_thread,
                                ThreadMatrixC,
                                Constant<bool, TransC>,
                                FloatC* p_c_thread,
                                Accumulator)
{
    // do something
}

template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          bool TransA,
          bool TransB,
          unsigned BatchSize,
          unsigned BlockMatrixStrideA,
          unsigned BlockMatrixStrideB,
          unsigned BatchPerThread,
          unsigned MPerThread,
          unsigned NPerThread,
          unsigned KPerThread,
          class Accumulator>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
    struct MatrixIndex
    {
        unsigned batch_begin;
        unsigned block_row_begin;
        unsigned block_col_begin;
    };

    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
    {
        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};

        // logical (un-transposed) extents of the block- and thread-level tiles
        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();

        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;

        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        static_assert(BlockSize >= ((BatchSize + BatchPerThread - 1) / BatchPerThread) *
                                       num_threads_per_batch,
                      "not enough thread!");

        const auto mtx_c_index = CalculateThreadMatrixCIndex(get_thread_local_id());

        mMyThreadOffsetA = xxx;
        mMyThreadOffsetB = xxx;
    }

    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};
        constexpr auto c_block = BlockMatrixC{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};

        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();

        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;

        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        // this is wrong, need fix
        const unsigned batch_begin = thread_id / num_threads_per_batch * BatchPerThread;

        const unsigned tmp = thread_id - (thread_id / num_threads_per_batch) *
                                             (num_threads_per_row * num_threads_per_col);

        const unsigned thread_matrix_row_id = tmp / num_threads_per_row;

        const unsigned thread_matrix_col_id = tmp - thread_matrix_row_id * num_threads_per_row;

        return MatrixIndex{
            batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void run(FloatA* const p_a_block, FloatB* const p_b_block, FloatC* p_c_thread) const
    {
        // do something
    }

    private:
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;
};
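Both the constructor and CalculateThreadMatrixCIndex decompose a flat thread id into (batch, row, col) coordinates of a thread's C tile, and the in-source comment flags the current decomposition as needing a fix (plausibly the row divisor, num_threads_per_row vs. num_threads_per_col). A host-side sketch of the mapping as currently written, with assumed tile counts, useful for eyeballing which threads land where:

#include <cstdio>

int main()
{
    // assumed tile counts, mirroring ceil(m_block/m_thread) and ceil(n_block/n_thread)
    const unsigned num_threads_per_row = 4;
    const unsigned num_threads_per_col = 8;
    const unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;
    const unsigned BatchPerThread = 1, m_thread = 4, n_thread = 4;

    for(unsigned thread_id = 0; thread_id < 2 * num_threads_per_batch; ++thread_id)
    {
        const unsigned batch_begin = thread_id / num_threads_per_batch * BatchPerThread;
        const unsigned tmp = thread_id % num_threads_per_batch; // remainder within one batch
        const unsigned row_id = tmp / num_threads_per_row;
        const unsigned col_id = tmp - row_id * num_threads_per_row;
        printf("thread %2u -> batch %u, C tile origin (%u, %u)\n",
               thread_id, batch_begin, row_id * m_thread, col_id * n_thread);
    }
    return 0;
}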
src/include/gridwise_implicit_gemm_convolution.cuh  (new file, mode 100644)
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "threadwise_tensor_op.cuh"

template <unsigned GridSize,
          unsigned BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          unsigned NPerBlock,
          unsigned KPerBlock,
          unsigned CPerBlock,
          unsigned HoPerBlock,
          unsigned WoPerBlock,
          unsigned KPerThread,
          unsigned CPerThread,
          unsigned HoPerThread,
          unsigned WoPerThread>
__global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
                                                             Float* const __restrict__ p_in_global,
                                                             WeiGlobalDesc,
                                                             Float* const __restrict__ p_wei_global,
                                                             OutGlobalDesc,
                                                             Float* __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the format of input in LDS is [C,Hi,Wi,N];
    // for GEMM trans([C,K]) * [C,Wo*N], we need one thread to do all of "N".
    // If we used [C,Hi,N,Wi,N] in LDS, then NPerThread could differ from NPerBlock.
    constexpr unsigned NPerThread = NPerBlock;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto True  = Constant<bool, true>{};
    constexpr auto False = Constant<bool, false>{};

    constexpr auto in_nchw_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcsr_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

    constexpr unsigned S = wei_kcsr_global_desc.GetLength(I2);
    constexpr unsigned R = wei_kcsr_global_desc.GetLength(I3);

    constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

    // block
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
    constexpr auto wei_srck_block_desc =
        make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});

    // LDS
    constexpr unsigned in_block_size  = in_chwn_block_desc.GetElementSpace();
    constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_size];
    __shared__ Float p_wei_block[wei_block_size];

    // thread
    constexpr auto out_hkwn_thread_desc = xxxxxx();

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    for(unsigned c_block_data_begin = 0;
        c_block_data_begin < in_nchw_global_desc.GetLength(I1);
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // input: global mem to LDS,
        // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
        constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};

        blockwise_4d_tensor_copy_reorder<BlockSize>(in_nchw_global_desc,
                                                    p_in_global,
                                                    in_chwn_block_desc,
                                                    p_in_block,
                                                    in_chwn_block_desc,
                                                    reorder_nchw2chwn);

        // matrix view of input
        constexpr unsigned in_row = in_chwn_block_desc.GetLength(I0);
        constexpr unsigned in_col = in_chwn_block_desc.GetLength(I1) *
                                    in_chwn_block_desc.GetLength(I2) *
                                    in_chwn_block_desc.GetLength(I3);

        constexpr auto in_cxhwn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<in_row>{}, Number<in_col>{}, Number<in_col>{});

        // weight: global mem to LDS,
        // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
        constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};

        blockwise_4d_tensor_copy_reorder<BlockSize>(wei_kcsr_global_desc,
                                                    p_wei_global,
                                                    wei_srck_block_desc,
                                                    p_wei_block,
                                                    wei_srck_block_desc,
                                                    reorder_kcsr2srck);

        // matrix view of wei
        constexpr unsigned wei_row = wei_srck_block_desc.GetLength(I0) *
                                     wei_srck_block_desc.GetLength(I1) *
                                     wei_srck_block_desc.GetLength(I2);
        constexpr unsigned wei_col = wei_srck_block_desc.GetLength(I3);

        constexpr auto wei_srcxk_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<wei_row>{}, Number<wei_col>{}, Number<wei_col>{});

        __syncthreads();

        // a series of batched GEMMs
        // blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
        constexpr auto a_block_mtx_desc = wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{});
        constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});

        auto f_accum = [](auto& c, auto& v) { c += v; };

        const auto blockwise_batch_gemm =
            blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
                                                                       decltype(a_block_mtx_desc),
                                                                       decltype(b_block_mtx_desc),
                                                                       true,
                                                                       false,
                                                                       HoPerBlock,
                                                                       0,
                                                                       xxx_b_matrix_stride,
                                                                       HoPerThread,
                                                                       KPerThread,
                                                                       NPerThread * WoPerThread,
                                                                       CPerThread,
                                                                       decltype(f_accum)>{};

        // loop over filter points
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
                blockwise_batch_gemm.run(
                    p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex(xxxxx, xxxx),
                    p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(xxxx, xxxx),
                    p_out_thread);
            }
        }
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id());

    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
    const unsigned k_thread_data_begin  = matrix_c_index.block_col_begin;
    const unsigned wo_thread_data_begin = matrix_c_index.block_row_begin / NPerThread;

    // output: register to global mem,
    // convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo]
    constexpr auto reorder_hkwn2nkhw = Sequence<2, 1, 3, 0>{};

    threadwise_4d_tensor_copy_reorder(
        out_hkwn_thread_desc,
        p_out_thread,
        out_nkhw_global_desc,
        p_out_global + out_nkhw_global_desc.GetIndex(n_block_data_begin,
                                                     k_block_data_begin + k_thread_data_begin,
                                                     ho_block_data_begin + ho_thread_data_begin,
                                                     wo_block_data_begin + wo_thread_data_begin),
        out_hkwn_thread_desc,
        reorder_hkwn2nkhw);
}
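For reference, the decomposition this kernel implements, a series of batched GEMMs with one GEMM per filter point (s, r), is equivalent to direct convolution with the (s, r) loops hoisted outermost. A minimal host-side sketch of that structure with assumed small shapes (stride 1, no padding); this is an illustration of the idea, not the library's API:

#include <cstdio>
#include <vector>

int main()
{
    // assumed shapes: in[N,C,Hi,Wi], wei[K,C,S,R], out[N,K,Ho,Wo]
    const unsigned N = 1, K = 2, C = 3, Ho = 4, Wo = 4, S = 3, R = 3;
    const unsigned Hi = Ho + S - 1, Wi = Wo + R - 1;

    std::vector<float> in(N * C * Hi * Wi, 1.0f);
    std::vector<float> wei(K * C * S * R, 1.0f);
    std::vector<float> out(N * K * Ho * Wo, 0.0f);

    for(unsigned s = 0; s < S; ++s)
        for(unsigned r = 0; r < R; ++r)
            // one small GEMM per filter point: C is the contracted dimension,
            // out_matrix[K, Ho*Wo*N] += trans(wei_matrix[C, K]) * in_matrix[C, Ho*Wo*N]
            for(unsigned n = 0; n < N; ++n)
                for(unsigned k = 0; k < K; ++k)
                    for(unsigned ho = 0; ho < Ho; ++ho)
                        for(unsigned wo = 0; wo < Wo; ++wo)
                            for(unsigned c = 0; c < C; ++c)
                                out[((n * K + k) * Ho + ho) * Wo + wo] +=
                                    wei[((k * C + c) * S + s) * R + r] *
                                    in[((n * C + c) * Hi + (ho + s)) * Wi + (wo + r)];

    // with all-ones inputs every output element equals C*S*R = 27
    printf("out[0] = %f\n", out[0]);
    return 0;
}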