Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c8667736
Commit
c8667736
authored
Feb 06, 2019
by
Chao Liu
Browse files
unroll some loop, register double buffer gemm
parent
1b323316
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
158 additions
and
36 deletions
+158
-36
driver/conv.cu
driver/conv.cu
+1
-1
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh
...ice_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh
+1
-31
src/include/blockwise_gemm.cuh
src/include/blockwise_gemm.cuh
+147
-0
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
...e/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
+9
-4
No files found.
driver/conv.cu
View file @
c8667736
...
@@ -611,7 +611,7 @@ int main()
...
@@ -611,7 +611,7 @@ int main()
nrepeat
);
nrepeat
);
#endif
#endif
#if
0
#if
1
if
(
S
==
3
&&
R
==
3
)
if
(
S
==
3
&&
R
==
3
)
{
{
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
...
...
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh
View file @
c8667736
...
@@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
...
@@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
Tensor
<
T
>
out_knhw
(
make_TensorDescriptor
(
out_knhw_desc
));
Tensor
<
T
>
out_knhw
(
make_TensorDescriptor
(
out_knhw_desc
));
#if
0
#if
1
// 1x1, 28x28
// 1x1, 28x28
constexpr
unsigned
BPerBlock
=
64
;
constexpr
unsigned
BPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned GemmMPerThreadSubC = 16;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 8;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 2;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif
1
// 1x1, 28x28 try
constexpr
unsigned
BPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
...
...
src/include/blockwise_gemm.cuh
View file @
c8667736
...
@@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
#pragma unroll
// loop over k
// loop over k
for
(
unsigned
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
for
(
unsigned
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
{
{
#pragma unroll
// copy A-sub to form A
// copy A-sub to form A
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
{
...
@@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx
.
GetLengths
());
a_thread_sub_mtx
.
GetLengths
());
}
}
#pragma unroll
// copy B-sub to form B
// copy B-sub to form B
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
{
...
@@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
f_accum
);
f_accum
);
}
}
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run_RegisterDoubleBuffer
(
FloatA
*
const
p_a_block
,
FloatB
*
const
p_b_block
,
FloatC
*
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
Constant
<
bool
,
true
>
{};
constexpr
auto
False
=
Constant
<
bool
,
false
>
{};
const
auto
a_block_mtx
=
BlockMatrixA
{};
// constexpr doesn't compile
const
auto
b_block_mtx
=
BlockMatrixB
{};
// constexpr doesn't compile
const
auto
c_thread_mtx
=
ThreadMatrixC
{};
// constexpr doesn't compile
constexpr
unsigned
M
=
a_block_mtx
.
NCol
();
constexpr
unsigned
N
=
b_block_mtx
.
NCol
();
constexpr
unsigned
K
=
a_block_mtx
.
NRow
();
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
const
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
// constexpr doesn't compile
const
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// constexpr doesn't compile
// thread A-sub, B-sub for copy
const
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
// constexpr doesn't compile
const
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
// constexpr doesn't compile
FloatA
p_a_thread_0
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread_0
[
b_thread_mtx
.
GetElementSpace
()];
FloatA
p_a_thread_1
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread_1
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// preload A, B
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
// copy A-sub to form A
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
mMyThreadOffsetA
+
m_repeat
*
MPerLevel1Cluster
,
a_thread_sub_mtx
,
p_a_thread_0
+
m_repeat
*
MPerThreadSubC
,
a_thread_sub_mtx
.
GetLengths
());
}
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
// copy B-sub to form B
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
mMyThreadOffsetB
+
n_repeat
*
NPerLevel1Cluster
,
b_thread_sub_mtx
,
p_b_thread_0
+
n_repeat
*
NPerThreadSubC
,
b_thread_sub_mtx
.
GetLengths
());
}
bool
even_loop
=
true
;
#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
+
1
<
K
;
k_begin
+=
KPerThreadLoop
,
even_loop
=
!
even_loop
)
{
// loop over k
FloatA
*
p_a_thread_now
=
even_loop
?
p_a_thread_0
:
p_a_thread_1
;
FloatB
*
p_b_thread_now
=
even_loop
?
p_b_thread_0
:
p_b_thread_1
;
FloatA
*
p_a_thread_next
=
even_loop
?
p_a_thread_1
:
p_a_thread_0
;
FloatB
*
p_b_thread_next
=
even_loop
?
p_b_thread_1
:
p_b_thread_0
;
// preload next A, B
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
// copy A-sub to form A
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
mMyThreadOffsetA
+
(
k_begin
+
1
)
*
a_block_mtx
.
RowStride
()
+
m_repeat
*
MPerLevel1Cluster
,
a_thread_sub_mtx
,
p_a_thread_next
+
m_repeat
*
MPerThreadSubC
,
a_thread_sub_mtx
.
GetLengths
());
}
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
// copy B-sub to form B
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
mMyThreadOffsetB
+
(
k_begin
+
1
)
*
b_block_mtx
.
RowStride
()
+
n_repeat
*
NPerLevel1Cluster
,
b_thread_sub_mtx
,
p_b_thread_next
+
n_repeat
*
NPerThreadSubC
,
b_thread_sub_mtx
.
GetLengths
());
}
// C = A * B
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread_now
,
b_thread_mtx
,
False
,
p_b_thread_now
,
c_thread_mtx
,
False
,
p_c_thread
,
f_accum
);
}
// last loop
{
even_loop
=
!
even_loop
;
FloatA
*
p_a_thread_now
=
even_loop
?
p_a_thread_0
:
p_a_thread_1
;
FloatB
*
p_b_thread_now
=
even_loop
?
p_b_thread_0
:
p_b_thread_1
;
// C = A * B
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread_now
,
b_thread_mtx
,
False
,
p_b_thread_now
,
c_thread_mtx
,
False
,
p_c_thread
,
f_accum
);
}
}
};
};
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
View file @
c8667736
...
@@ -237,10 +237,15 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
...
@@ -237,10 +237,15 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
{
{
auto
f_accum
=
[](
auto
&
acc
,
const
auto
&&
v
)
{
acc
+=
v
;
};
auto
f_accum
=
[](
auto
&
acc
,
const
auto
&&
v
)
{
acc
+=
v
;
};
blockwise_gemm
.
Run
(
p_wei_block
+
wei_csrk_block_desc
.
Get1dIndex
(
0
,
s
,
r
,
0
),
#if 1
p_in_block
+
s
*
Wi
+
r
,
blockwise_gemm
.
Run
p_out_thread
,
#else
f_accum
);
blockwise_gemm
.
Run_RegisterDoubleBuffer
#endif
(
p_wei_block
+
wei_csrk_block_desc
.
Get1dIndex
(
0
,
s
,
r
,
0
),
p_in_block
+
s
*
Wi
+
r
,
p_out_thread
,
f_accum
);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment