Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e38ee30a
Commit
e38ee30a
authored
Jan 16, 2020
by
Chao Liu
Browse files
tweaking
parent
91e0de2e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
30 deletions
+42
-30
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+6
-2
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
...a_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
+16
-8
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+20
-20
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
e38ee30a
...
...
@@ -126,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM: atomic add
// GEMM
constexpr
auto
in_memory_op
=
(
Y
<=
ConvStrideH
&&
X
<=
ConvStrideW
)
?
InMemoryDataOperation
::
none
:
InMemoryDataOperation
::
atomic_add
;
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
...
...
@@ -135,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
decltype
(
wei_k_e_global_desc
),
decltype
(
out_k_b_global_desc
),
decltype
(
in_e_b_global_desc
),
InMemoryDataOperation
::
atomic_add
,
in_memory_op
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
e38ee30a
...
...
@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
}
}
// input: register to global memory, atomic add
{
#if 1 // debug
// input: register to global memory, atomic add
constexpr
auto
in_memory_op
=
(
Y
<=
ConvStrideH
&&
X
<=
ConvStrideW
)
?
InMemoryDataOperation
::
none
:
InMemoryDataOperation
::
atomic_add
;
#else
constexpr
auto
in_memory_op
=
InMemoryDataOperation
::
atomic_add
;
#endif
constexpr
index_t
E1
=
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
E0
=
E
/
E1
;
...
...
@@ -426,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
InThreadCopyDstDataPerWrite_B
,
AddressSpace
::
vgpr
,
AddressSpace
::
global
,
InMemoryDataOperation
::
atomic_add
>
({
0
,
0
,
0
,
0
,
0
,
0
},
{
e_thread_data_on_global
/
E1
,
e_thread_data_on_global
%
E1
,
0
,
b_thread_data_on_global
/
B1
,
b_thread_data_on_global
%
B1
,
0
})
in_memory_op
>
({
0
,
0
,
0
,
0
,
0
,
0
},
{
e_thread_data_on_global
/
E1
,
e_thread_data_on_global
%
E1
,
0
,
b_thread_data_on_global
/
B1
,
b_thread_data_on_global
%
B1
,
0
})
.
Run
(
p_in_thread
,
p_in_global
);
}
}
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
e38ee30a
...
...
@@ -23,10 +23,10 @@ int main(int argc, char* argv[])
{
using
namespace
launcher
;
#if
1
#if
0
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
...
...
@@ -59,7 +59,7 @@ int main(int argc, char* argv[])
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
...
...
@@ -83,13 +83,13 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 1x1 filter, 7x7 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
...
...
@@ -158,13 +158,13 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
1
#elif
0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
...
...
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
#endif
}
#if
1
#if
0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif
1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
...
...
@@ -257,17 +257,17 @@ int main(int argc, char* argv[])
#else
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif
(
in_nchw_desc
,
in_nchw_device
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
(
in_nchw_desc
,
in_nchw_device
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
if
(
do_verification
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment