gaoqiong / composable_kernel · Commit 84d9802d

adding implicit gemm

Authored Jan 15, 2019 by Chao Liu
Parent: aa0199a3

Showing 15 changed files with 261 additions and 162 deletions
driver/conv.cu (+1, -1)
driver/device_implicit_gemm_convolution.cuh (+4, -4)
src/include/ConstantMatrixDescriptor.cuh (+46, -0)
src/include/ConstantTensorDescriptor.cuh (+0, -64)
src/include/blockwise_direct_convolution.cuh (+1, -1)
src/include/blockwise_tensor_op.cuh (+1, -1)
src/include/common.cuh (+69, -1)
src/include/conv_common.cuh (+1, -1)
src/include/gemm.cuh (+40, -25)
src/include/gridwise_direct_convolution_1.cuh (+1, -1)
src/include/gridwise_direct_convolution_2.cuh (+1, -1)
src/include/gridwise_implicit_gemm_convolution.cuh (+93, -59)
src/include/gridwise_winograd_convolution.cuh (+1, -1)
src/include/threadwise_direct_convolution.cuh (+1, -1)
src/include/threadwise_tensor_op.cuh (+1, -1)
driver/conv.cu
@@ -4,7 +4,7 @@
 #include <cstdlib>
 #include "nvToolsExt.h"
 #include "tensor.hpp"
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
 #include "conv_common.cuh"
 #include "device_direct_convolution_1.cuh"
 #include "device_direct_convolution_2.cuh"
 ...
driver/device_implicit_gemm_convolution.cuh
@@ -27,14 +27,14 @@ void device_implicit_gemm_convolution(
 #if 1
     constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 128;
+    constexpr unsigned KPerBlock  = 64;
     constexpr unsigned CPerBlock  = 4;
     constexpr unsigned HoPerBlock = 2;
     constexpr unsigned WoPerBlock = 32;

     constexpr unsigned NPerThread  = 2;
     constexpr unsigned KPerThread  = 8;
     constexpr unsigned CPerThread  = 2;
     constexpr unsigned HoPerThread = 1;
     constexpr unsigned WoPerThread = 4;
 ...
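For orientation, the only tuning change visible here halves KPerBlock. A rough, hypothetical LDS-footprint calculation (not part of the commit; it assumes a 3x3 filter, i.e. S = R = 3, and 4-byte floats) shows what that buys for the per-block weight tile:

#include <cstdio>

int main()
{
    constexpr unsigned S = 3, R = 3; // assumed filter size, not fixed by this file
    constexpr unsigned NPerBlock = 2, CPerBlock = 4, HoPerBlock = 2, WoPerBlock = 32;
    constexpr unsigned HiPerBlock = HoPerBlock + S - 1; // 4
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1; // 34

    // element counts of the two LDS tiles the kernel allocates
    constexpr unsigned in_tile      = CPerBlock * HiPerBlock * WiPerBlock * NPerBlock; // 1088
    constexpr unsigned wei_tile_128 = S * R * CPerBlock * 128;                         // 4608
    constexpr unsigned wei_tile_64  = S * R * CPerBlock * 64;                          // 2304

    std::printf("in: %u B, wei(KPerBlock=128): %u B, wei(KPerBlock=64): %u B\n",
                in_tile * 4, wei_tile_128 * 4, wei_tile_64 * 4); // 4352, 18432, 9216
}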
src/include/ConstantMatrixDescriptor.cuh (new file, mode 100644)

#pragma once
#include "common.cuh"

template <unsigned NRow, unsigned NCol, unsigned RowStride>
struct ConstantMatrixDescriptor
{
    __host__ __device__ ConstantMatrixDescriptor()
    {
        static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
    }

    __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }

    __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }

    __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }

    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }

    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }

    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
        return irow * RowStride + icol;
    }

    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                               Number<SubNCol>) const
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
    }
};

template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
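A minimal usage sketch of the new descriptor (not from the commit; compiled with nvcc, since common.cuh defines __device__ helpers). Note the descriptor objects are built with const auto rather than constexpr auto, matching the "constexpr doesn't compile" remarks that appear later in this commit:

#include <cstdio>
#include "ConstantMatrixDescriptor.cuh"

int main()
{
    // an 8x16 row-major matrix, and a 4x8 sub-view that keeps the parent row stride 16
    const auto mtx = make_ConstantMatrixDescriptor(Number<8>{}, Number<16>{});
    const auto sub = mtx.MakeSubMatrixDescriptor(Number<4>{}, Number<8>{});

    std::printf("size %u, space %u, (1,2) -> %u\n",
                sub.GetElementSize(),  // 4 * 8  = 32
                sub.GetElementSpace(), // 4 * 16 = 64 (stride padding included)
                sub.Get1dIndex(1, 2)); // 1 * 16 + 2 = 18
}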
src/include/constant_tensor_descriptor.cuh → src/include/ConstantTensorDescriptor.cuh

The header is renamed, and the Constant, Number, and Sequence helpers are deleted from it; they move into src/include/common.cuh (the moved definitions are shown in full under that file below).

 #pragma once
 #include "common.cuh"

-template <class T, T N>
-struct Constant
-{
-    static const T mValue = N;
-};
-
-template <unsigned N>
-using Number = Constant<unsigned, N>;
-
-template <unsigned... Is>
-struct Sequence
-{
-    ... // nDim, mData, Get(Number<I>), and the 2-, 3-, and 4-index Reorder
-        // overloads, as now defined in common.cuh
-};

 template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
 ...
src/include/blockwise_direct_convolution.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
 #include "threadwise_tensor_op.cuh"
 #include "threadwise_direct_convolution.cuh"
 ...
src/include/blockwise_tensor_op.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"

 template <unsigned BlockSize, class Float, class DstDesc, class F>
 __device__ void
 ...
src/include/common.cuh
@@ -12,4 +12,72 @@ struct is_same<T, T>
     static const bool value = true;
 };

+__device__ unsigned get_thread_local_id() { return threadIdx.x; }
+
+__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
+
+__device__ unsigned get_block_1d_id() { return blockIdx.x; }
+
+template <class T, T N>
+struct Constant
+{
+    static const T mValue = N;
+
+    __host__ __device__ constexpr T Get() const { return mValue; }
+};
+
+template <unsigned N>
+using Number = Constant<unsigned, N>;
+
+template <unsigned... Is>
+struct Sequence
+{
+    static constexpr unsigned nDim = sizeof...(Is);
+
+    const unsigned mData[nDim] = {Is...};
+
+    template <unsigned I>
+    __host__ __device__ constexpr unsigned Get(Number<I>) const
+    {
+        return mData[I];
+    }
+
+    template <unsigned I0, unsigned I1>
+    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
+    {
+        constexpr unsigned IR0 = Get(Number<I0>{});
+        constexpr unsigned IR1 = Get(Number<I1>{});
+
+        return Sequence<IR0, IR1>{};
+    }
+
+    template <unsigned I0, unsigned I1, unsigned I2>
+    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
+    {
+        constexpr unsigned IR0 = Get(Number<I0>{});
+        constexpr unsigned IR1 = Get(Number<I1>{});
+        constexpr unsigned IR2 = Get(Number<I2>{});
+
+        return Sequence<IR0, IR1, IR2>{};
+    }
+
+    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
+    __host__ __device__ constexpr auto
+    Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
+    {
+        constexpr unsigned IR0 = Get(Number<I0>{});
+        constexpr unsigned IR1 = Get(Number<I1>{});
+        constexpr unsigned IR2 = Get(Number<I2>{});
+        constexpr unsigned IR3 = Get(Number<I3>{});
+
+        return Sequence<IR0, IR1, IR2, IR3>{};
+    }
+
+    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
+    __host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
+    {
+        constexpr unsigned IR0 = Get(Number<I0>{});
+        constexpr unsigned IR1 = Get(Number<I1>{});
+        constexpr unsigned IR2 = Get(Number<I2>{});
+        constexpr unsigned IR3 = Get(Number<I3>{});
+
+        return Sequence<IR0, IR1, IR2, IR3>{};
+    }
+};
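A small sketch (not part of the commit; nvcc, C++14) of the Sequence API added above: Get is positional lookup into mData, and Reorder gathers several positions into a new Sequence type. This is the machinery the implicit-GEMM kernel below uses to describe layout permutations such as NCHW to CHWN:

#include "common.cuh"

void sequence_example()
{
    constexpr auto s = Sequence<5, 6, 7, 8>{};

    // Get(Number<I>{}) returns mData[I]
    static_assert(s.Get(Number<0>{}) == 5, "");
    static_assert(s.Get(Number<3>{}) == 8, "");

    // Reorder gathers by position: the result type here is Sequence<8, 5, 6, 7>
    const auto r = s.Reorder(Number<3>{}, Number<0>{}, Number<1>{}, Number<2>{});
    static_assert(r.nDim == 4, "");
}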
src/include/conv_common.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"

 // this is ugly, only for 4d
 template <class InDesc, class WeiDesc>
 ...
src/include/gemm.cuh
 #pragma once

-template <class ThreadMatrixA,
-          class ThreadMatrixB,
-          class ThreadMatrixC,
-          bool TransA,
-          bool TransB,
-          bool TransC,
-          class FloatA,
-          class FloatB,
-          class FloatC,
+template <class ThreadMatrixA,
+          bool TransA,
+          class FloatA,
+          class ThreadMatrixB,
+          bool TransB,
+          class FloatB,
+          class ThreadMatrixC,
+          class FloatC,
           class Accumulator>
 __device__ void threadwise_gemm(ThreadMatrixA,
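For reference (not the commit's code), this is the scalar computation threadwise_gemm performs in the configuration the implicit-GEMM kernel below instantiates, TransA = true and TransB = false, i.e. C += transpose(A) * B on per-thread register tiles:

// C[m][n] += sum_k A[k][m] * B[k][n]; the Accumulator functor (f_accum below)
// supplies the "+=".
template <unsigned M, unsigned N, unsigned K>
void reference_threadwise_gemm_tn(const float (&a)[K][M], const float (&b)[K][N], float (&c)[M][N])
{
    for(unsigned k = 0; k < K; ++k)
        for(unsigned m = 0; m < M; ++m)
            for(unsigned n = 0; n < N; ++n)
                c[m][n] += a[k][m] * b[k][n];
}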
 ...
@@ -26,41 +27,51 @@ __device__ void threadwise_gemm(ThreadMatrixA,

-template <unsigned BlockSize,
-          class BlockMatrixA,
-          class BlockMatrixB,
-          bool TransA,
-          bool TransB,
-          unsigned BatchSize,
-          unsigned BlockMatrixStrideA,
-          unsigned BlockMatrixStrideB,
-          unsigned BatchPerThread,
-          unsigned KPerLoop,
+template <unsigned BlockSize,
+          class BlockMatrixA,
+          class BlockMatrixB,
+          class ThreadMatrixC,
+          bool TransA,
+          bool TransB,
+          bool TransC,
+          unsigned BlockMatrixStrideA,
+          unsigned BlockMatrixStrideB,
+          unsigned ThreadMatrixStrideC,
+          unsigned BatchSize,
+          unsigned BatchPerThread,
+          unsigned MPerThread,
+          unsigned NPerThread,
+          unsigned KPerThread,
           class Accumulator>
 struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
-
     struct MatrixIndex
     {
         unsigned batch_begin;
-        unsigned block_row_begin;
-        unsigned block_col_begin;
+        unsigned row_begin;
+        unsigned col_begin;
     };

     __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
     {
         static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

+#if 0
-        constexpr auto a_block = BlockMatrixA{};
-        constexpr auto b_block = BlockMatrixB{};
-
-        constexpr auto a_thread = ThreadMatrixA{};
-        constexpr auto b_thread = ThreadMatrixB{};
-        constexpr auto c_thread = ThreadMatrixC{};
-
-        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
-        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();
-
-        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
-        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();
+        constexpr auto a_block_desc = BlockMatrixA{};
+        constexpr auto b_block_desc = BlockMatrixB{};
+
+        constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
+        constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
+        constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
+        constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;
+
+        constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col>{};
+        constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col>{};
+        constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread>{};
+
+        constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
+        constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();
+
+        constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
+        constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();

         constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
         constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
 ...
@@ -72,12 +83,17 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
         const auto mtx_c_index = CalculateThreadMatrixCIndex(get_thread_local_id());

-        mMyThreadOffsetA = xxx;
-        mMyThreadoffSetB = xxx;
+        // mMyThreadOffsetA = xxx;
+        // mMyThreadoffSetB = xxx;
+#else
+        mMyThreadOffsetA = 0;
+        mMyThreadOffsetB = 0;
+#endif
     }

     __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
     {
+#if 0
         constexpr auto a_block = BlockMatrixA{};
         constexpr auto b_block = BlockMatrixB{};
         constexpr auto c_block = BlockMatrixC{};
 ...
@@ -104,6 +120,9 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
         return MatrixIndex{
             batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
+#else
+        return MatrixIndex{0, 0, 0};
+#endif
     }

     template <class FloatA, class FloatB, class FloatC>
 ...
@@ -111,8 +130,4 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
     {
         // do something
     }
+
+    private:
+    unsigned mMyThreadOffsetA = 0;
+    unsigned mMyThreadOffsetB = 0;
 };
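What this struct is meant to compute, as a plain reference loop (a sketch under assumed semantics; the names here are illustrative, not the struct's API): BatchSize independent C += transpose(A) * B products whose A, B, and C base pointers advance by fixed strides per batch. In the convolution below the batch dimension is Ho, BlockMatrixStrideA is 0 (the weight sub-matrix is shared by every batch), and BlockMatrixStrideB walks the input one row down per batch.

void reference_strided_batched_gemm(const float* a, const float* b, float* c,
                                    unsigned batch, unsigned stride_a, unsigned stride_b,
                                    unsigned stride_c, unsigned m, unsigned n, unsigned k,
                                    unsigned lda, unsigned ldb, unsigned ldc)
{
    for(unsigned g = 0; g < batch; ++g)
        for(unsigned i = 0; i < m; ++i)
            for(unsigned j = 0; j < n; ++j)
                for(unsigned l = 0; l < k; ++l)
                    c[g * stride_c + i * ldc + j] += a[g * stride_a + l * lda + i] // A read transposed
                                                   * b[g * stride_b + l * ldb + j];
}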
src/include/gridwise_direct_convolution_1.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
 #include "blockwise_tensor_op.cuh"
 #include "blockwise_direct_convolution.cuh"
 ...
src/include/gridwise_direct_convolution_2.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
 #include "blockwise_tensor_op.cuh"
 #include "blockwise_direct_convolution.cuh"
 #include "threadwise_tensor_op.cuh"
 ...
src/include/gridwise_implicit_gemm_convolution.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "common.cuh"
+#include "ConstantTensorDescriptor.cuh"
+#include "ConstantMatrixDescriptor.cuh"
 #include "blockwise_tensor_op.cuh"
 #include "threadwise_tensor_op.cuh"
+#include "gemm.cuh"

 template <unsigned GridSize,
           unsigned BlockSize,
 ...
@@ -45,59 +48,85 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
     constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
     constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

+    // divide block work: NCHW
+    constexpr unsigned NBlockWork =
+        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
+    constexpr unsigned KBlockWork =
+        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
+    constexpr unsigned HBlockWork =
+        (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
+    constexpr unsigned WBlockWork =
+        (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
+
+    unsigned itmp = get_block_1d_id();
+    const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
+    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
+    const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
+    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
+    const unsigned h_block_work_id = itmp / WBlockWork;
+    const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
+
+    const unsigned n_block_data_begin  = n_block_work_id * NPerBlock;
+    const unsigned k_block_data_begin  = k_block_work_id * KPerBlock;
+    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
+    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
+
+    const unsigned hi_block_data_begin = ho_block_data_begin;
+    const unsigned wi_block_data_begin = wo_block_data_begin;
+
     // tensor view of blockwise input and weight in LDS
     constexpr auto in_chwn_block_desc =
         make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
     constexpr auto wei_srck_block_desc =
         make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});

+    // matrix view of blockwise input and weight in LDS
+    constexpr auto in_cxhwn_block_mtx_desc = make_ConstantMatrixDescriptor(
+        Number<CPerBlock>{}, Number<HiPerBlock * WiPerBlock * NPerBlock>{});
+    constexpr auto wei_srcxk_block_mtx_desc =
+        make_ConstantMatrixDescriptor(Number<S * R * CPerBlock>{}, Number<KPerBlock>{});
+
     // tensor view of threadwise output in register
     constexpr auto out_hkwn_thread_desc =
         make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

-    // a series of batched GEMM
-    // blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
-    // A_matrix and B_matrix saved in LDS, c_matrix saved in register
-    //   A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
-    //   B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
-    //   C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
-    constexpr auto a_block_mtx_desc = wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(
-        Number<CPerBlock>{}, Number<KPerBlock>{});
-    constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
-        Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});
-
-    auto f_accum = [](auto& c, auto& v) { c += v; };
+    // a series of blockwise batched GEMM
+    // C_matrix += transpose(A_matrix) * B_matrix
+    // A_matrix and B_matrix saved in LDS, C_matrix saved in register
+    //   A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
+    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
+    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
+    const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
+        Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
+    const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
+        Number<CPerBlock>{},
+        Number<WoPerBlock * NPerBlock>{},
+        Number<in_chwn_block_desc.GetStride(I1)>{}); // constexpr doesn't compile
+    const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
+        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
+
+    auto f_accum = [](auto& c, auto& ab) { c += ab; };

     const auto blockwise_batch_gemm =
         blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<
             BlockSize,
-            a_block_mtx_desc,
-            b_block_mtx_desc,
+            decltype(a_cxk_block_mtx_desc),
+            decltype(b_cxwn_block_mtx_desc),
+            decltype(c_kxwn_thread_mtx_desc),
             true,
             false,
-            HoPerBlock,
+            false,
             0,
-            xxx_b_matrix_stride,
+            in_chwn_block_desc.GetStride(I1),
+            out_hkwn_thread_desc.GetStride(I1),
+            HoPerBlock,
             HoPerThread,
-            CPerThread,
+            KPerThread,
+            NPerThread * WoPerThread,
+            CPerThread,
             decltype(f_accum)>{};

     // LDS
     constexpr unsigned in_block_size  = in_chwn_block_desc.GetElementSpace();
     constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();

     __shared__ Float p_in_block[in_block_size];
     __shared__ Float p_wei_block[wei_block_size];

     // register
     Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
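A worked example (hypothetical sizes, not from the commit) of the block-work split introduced above: with output N = 16, K = 128, Ho = 64, Wo = 64 and this driver's tile sizes, the grid has 8 * 2 * 32 * 2 = 1024 blocks, and each block decodes its (n, k, h, w) tile coordinates from its 1-d block id exactly as the kernel does:

#include <cstdio>

int main()
{
    constexpr unsigned NBlockWork = 16 / 2;   // NPerBlock  = 2
    constexpr unsigned KBlockWork = 128 / 64; // KPerBlock  = 64
    constexpr unsigned HBlockWork = 64 / 2;   // HoPerBlock = 2
    constexpr unsigned WBlockWork = 64 / 32;  // WoPerBlock = 32

    unsigned itmp = 1000; // a sample 1-d block id
    const unsigned n_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_id * (KBlockWork * HBlockWork * WBlockWork);
    const unsigned k_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_id * (HBlockWork * WBlockWork);
    const unsigned h_id = itmp / WBlockWork;
    const unsigned w_id = itmp - h_id * WBlockWork;

    std::printf("block 1000 -> n %u, k %u, h %u, w %u\n", n_id, k_id, h_id, w_id); // 7, 1, 20, 0
}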
 ...
@@ -105,48 +134,53 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
     // set threadwise output tensor to 0
     threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

-    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
+    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
         c_block_data_begin += CPerBlock, __syncthreads())
     {
         // input: global mem to LDS,
         // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
         constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};

-        blockwise_4d_tensor_copy_reorder<BlockSize>(in_nchw_global_desc,
-                                                    p_in_global,
-                                                    in_chwn_block_desc,
-                                                    p_in_block,
-                                                    in_chwn_block_desc,
-                                                    reorder_nchw2chwn);
+        blockwise_4d_tensor_copy_reorder<BlockSize>(
+            in_nchw_global_desc,
+            p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
+                                                         c_block_data_begin,
+                                                         hi_block_data_begin,
+                                                         wi_block_data_begin),
+            in_chwn_block_desc,
+            p_in_block,
+            in_chwn_block_desc,
+            reorder_nchw2chwn);

         // weight: global mem to LDS,
         // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
         constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};

-        blockwise_4d_tensor_copy_reorder<BlockSize>(wei_csrk_global_desc,
-                                                    p_wei_global,
-                                                    wei_csrk_block_desc,
-                                                    p_wei_block,
-                                                    wei_csrk_block_desc,
-                                                    reorder_kcsr2csrk);
+        blockwise_4d_tensor_copy_reorder<BlockSize>(
+            wei_kcsr_global_desc,
+            p_wei_global + wei_kcsr_global_desc.Get1dIndex(k_block_data_begin,
+                                                           c_block_data_begin, 0, 0),
+            wei_srck_block_desc,
+            p_wei_block,
+            wei_srck_block_desc,
+            reorder_kcsr2srck);

         __syncthreads();

-        // loop over filter point
+        // a series of batched GEMM
         for(unsigned s = 0; s < S; ++s)
         {
             for(unsigned r = 0; r < R; ++r)
             {
                 blockwise_batch_gemm.run(
-                    p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
-                    p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
+                    p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex(xxxxx, xxxx),
+                    p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(xxxx, xxxx),
                     p_out_thread);
             }
         }
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id());
+        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

     const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
     const unsigned k_thread_data_begin  = matrix_c_index.col_begin;
 ...
@@ -160,10 +194,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
                                        out_hkwn_thread_desc,
                                        p_out_thread,
                                        out_nkhw_global_desc,
-                                       p_out_global + out_nkhw_global_desc.GetIndex(n_block_data_begin,
+                                       p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
                                            k_block_data_begin + k_thread_data_begin,
                                            ho_block_data_begin + ho_thread_data_begin,
                                            wo_block_data_begin + wo_thread_data_begin),
                                        out_hkwn_thread_desc,
                                        reorder_hkwn2nkhw);
 }
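The identity the whole file rests on, checked numerically on the host (a sketch, not the commit's code): a direct convolution equals S * R accumulated GEMMs, one per filter tap (s, r), each contracting over C against an input view shifted by (s, r). That is exactly what the s/r loop above feeds to the blockwise batched GEMM.

#include <cstdio>

int main()
{
    constexpr unsigned C = 2, K = 2, Hi = 4, Wi = 4, S = 3, R = 3;
    constexpr unsigned Ho = Hi - S + 1, Wo = Wi - R + 1;

    float in[C][Hi][Wi], wei[K][C][S][R], out[K][Ho][Wo] = {};

    // arbitrary deterministic data
    for(unsigned c = 0; c < C; ++c)
        for(unsigned h = 0; h < Hi; ++h)
            for(unsigned w = 0; w < Wi; ++w)
                in[c][h][w] = 0.5f * float(c + h * w);

    for(unsigned k = 0; k < K; ++k)
        for(unsigned c = 0; c < C; ++c)
            for(unsigned s = 0; s < S; ++s)
                for(unsigned r = 0; r < R; ++r)
                    wei[k][c][s][r] = float(k + c) - 0.25f * float(s * r);

    // one rank-C GEMM per filter tap (s, r):
    // out[k][ho][wo] += sum_c wei[k][c][s][r] * in[c][ho + s][wo + r]
    for(unsigned s = 0; s < S; ++s)
        for(unsigned r = 0; r < R; ++r)
            for(unsigned k = 0; k < K; ++k)
                for(unsigned ho = 0; ho < Ho; ++ho)
                    for(unsigned wo = 0; wo < Wo; ++wo)
                        for(unsigned c = 0; c < C; ++c)
                            out[k][ho][wo] += wei[k][c][s][r] * in[c][ho + s][wo + r];

    std::printf("out[0][0][0] = %f\n", out[0][0][0]);
}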
src/include/gridwise_winograd_convolution.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
 #include "blockwise_winograd_transform.cuh"
 #include "threadwise_winograd_transform.cuh"
 ...
src/include/threadwise_direct_convolution.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"

 // optimized for scenario if p_in, p_wei, p_out are in register
 template <class Float, class InDesc, class WeiDesc, class OutDesc>
 ...
src/include/threadwise_tensor_op.cuh
 #pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"

 template <class Float, class Desc, class F>
 __device__ void
 threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
 ...