Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2eeeb176
Commit
2eeeb176
authored
Jul 30, 2019
by
Chao Liu
Browse files
refactor
parent
08cbac98
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
43 deletions
+34
-43
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+32
-41
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+2
-2
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
2eeeb176
...
...
@@ -18,6 +18,8 @@ template <index_t GridSize,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
class
ConvStrides
,
class
ConvDilations
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
...
...
@@ -64,10 +66,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
...
...
@@ -75,10 +74,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_h_w_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_n_c_h_w_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_n_c_h_w_global_desc
.
GetLength
(
I3
);
constexpr
index_t
N
=
in_n_c_h_w_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
K
=
out_n_k_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_n_k_h_w_global_desc
.
GetLength
(
I2
);
...
...
@@ -87,6 +84,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
static_assert
(
N
%
(
N1
*
N2
)
==
0
,
"wrong! cannot divice N evenly among thread"
);
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
...
...
@@ -95,6 +98,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
index_t
E
=
C
*
Y
*
X
;
// sanity-check for vectorized memory load
static_assert
(
ConvStrideW
==
1
||
InBlockCopySrcDataPerRead_B
==
1
,
"wrong! global vector load of input tensor is wrong"
);
static_assert
((
X
==
1
||
ConvDilationW
%
InBlockCopySrcDataPerRead_B
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// divide block work by [K, B]
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
EPerBlock
==
0
,
"wrong! cannot divide work evenly among block"
);
...
...
@@ -113,15 +124,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Ho
>
{})
.
Slice
(
I3
,
Number
<
Wo
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Ho
>
{},
Number
<
ConvStrideH
>
{})
.
StridedSlice
(
I3
,
Number
<
Wo
>
{},
Number
<
ConvStrideW
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
// batch descritpor for device memory
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
.
Slice
(
I3
,
Number
<
X
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Y
>
{},
Number
<
ConvDilationH
>
{})
.
StridedSlice
(
I3
,
Number
<
X
>
{},
Number
<
ConvDilationW
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
...
...
@@ -131,17 +144,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
Sequence
<
3
,
6
,
7
>
{},
Sequence
<
5
>
{});
#if 0
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc,
"in_n0_n1_n2_h_w_global_desc: ");
print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: ");
print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc,
"in_e_n1_b_n2_global_merged_desc: ");
}
#endif
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
in_e_n1_b_n2_block_desc
=
make_ConstantTensorDescriptor_aligned
(
...
...
@@ -206,13 +208,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_e_k_block_desc
.
GetStride
(
I0
)
>
{}
);
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor
(
wei_e_k_block_desc
);
constexpr
auto
b_e_n1bn2_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
Number
<
N1
*
BPerBlock
*
N2
>
{},
Number
<
in_e_n1_b_n2_block_desc
.
GetStride
(
I0
)
>
{});
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor
(
in_e_n1_b_n2_block_desc
.
Unfold
(
I1
,
I3
));
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
...
...
@@ -242,15 +243,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
// choose GEMM implementation here
const
auto
run_blockwise_gemm
=
[
&
](
auto
...
Xs
)
{
#if 1
return
blockwise_gemm
.
Run
(
Xs
...);
#else
return
blockwise_gemm
.
Run_amd_asm
(
Xs
...);
#endif
};
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDstDataPerWrite_K
,
...
...
@@ -281,7 +273,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
__syncthreads
();
run_
blockwise_gemm
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
__syncthreads
();
...
...
@@ -293,7 +285,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
{
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
K0
=
K
/
(
K1
*
K2
);
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
2eeeb176
...
...
@@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if
0
#if
1
// each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
...
...
@@ -94,7 +94,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
1
#elif
0
// each thread hold 32 data
constexpr
index_t
BlockSize
=
256
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment