Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
8669e242
Commit
8669e242
authored
Jul 15, 2019
by
Chao Liu
Browse files
debugging
parent
5f82fdd9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
12 deletions
+33
-12
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
+28
-7
driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp
+4
-4
driver/src/driver.cpp
driver/src/driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
8669e242
...
...
@@ -100,7 +100,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
constexpr
index_t
B
=
N0
*
Ho0
*
Wo0
;
static_assert
(
N
==
N0
*
N1
*
N2
&&
Ho
==
Ho
*
Ho1
*
Ho2
&&
Wo
==
Wo0
*
Wo1
*
Wo2
,
"wrong!"
);
static_assert
(
N
==
N0
*
N1
*
N2
&&
Ho
==
Ho0
*
Ho1
*
Ho2
&&
Wo
==
Wo0
*
Wo1
*
Wo2
,
"wrong!"
);
static_assert
((
X
==
1
||
ConvDilationW
%
InBlockCopyDataPerAccess_W2
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
...
...
@@ -179,12 +180,6 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyDataPerAccess_W2
>
({
0
,
0
,
0
,
0
,
b_block_data_on_global
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
});
#if 0
{
printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset);
}
#endif
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
...
...
@@ -214,6 +209,19 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyDstDataPerWrite_K
>
(
{
0
,
k_block_data_on_global
},
{
0
,
0
});
#if 0
if(get_block_1d_id() == 0)
{
printf("id (%d %d), in offset: %d %d, wei offset %d %d\n",
get_block_1d_id(),
get_thread_local_1d_id(),
blockwise_in_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadDstOffset,
blockwise_wei_copy.mThreadSrcOffset,
blockwise_wei_copy.mThreadDstOffset);
}
#endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
...
...
@@ -324,6 +332,19 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_block_on_global
,
p_wei_register_clipboard
);
#if 1
if
(
get_block_1d_id
()
==
0
)
{
printf
(
"tid (%d %d), %f %f %f %f
\n
"
,
get_block_1d_id
(),
get_thread_local_1d_id
(),
p_wei_register_clipboard
[
0
],
p_wei_register_clipboard
[
1
],
p_wei_register_clipboard
[
2
],
p_wei_register_clipboard
[
3
]);
}
#endif
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
...
...
driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp
View file @
8669e242
...
...
@@ -90,14 +90,14 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
InBlockCopyDataPerAccess_W2
=
4
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
2
,
2
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
4
,
64
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
2
;
#endif
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
...
...
driver/src/driver.cpp
View file @
8669e242
...
...
@@ -491,7 +491,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
#if
0
#if
1
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
#elif 0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment