Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9750de73
"...composable_kernel_rocm.git" did not exist on "f95267f166927bee1d806cefbdc142b2e35f640f"
Commit
9750de73
authored
Jan 13, 2020
by
Chao Liu
Browse files
adding bwd data v3r1
parent
ef2664fb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
194 additions
and
142 deletions
+194
-142
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+148
-135
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
+42
-3
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+2
-2
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+2
-2
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
9750de73
...
...
@@ -205,142 +205,155 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
#if 1 // debug
// get a series of GEMMs
auto
f_get_gemm
=
[
&
](
auto
ytilda_
,
auto
xtilda_
)
{
constexpr
index_t
ytilda
=
decltype
(
ytilda_
){};
constexpr
index_t
xtilda
=
decltype
(
xtilda_
){};
constexpr
index_t
Ydotnonzero
=
(
ytilda
+
1
)
*
Ydot
<=
Y
?
Ydot
:
Y
%
Ydot
;
constexpr
index_t
Xdotnonzero
=
(
xtilda
+
1
)
*
Xdot
<=
X
?
Xdot
:
X
%
Xdot
;
// A matrix
constexpr
auto
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_1_htildatrim_1_wtildatrim_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_1_htildatrim_1_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
return
gridwise_gemm
;
};
// GEMMs
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda_
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda_
)
{
#else
static_for
<
1
,
2
,
1
>
{}([
&
](
auto
ytilda_
)
{
static_for
<
1
,
2
,
1
>
{}([
&
](
auto
xtilda_
)
{
#endif
constexpr
index_t
ytilda
=
decltype
(
ytilda_
){};
constexpr
index_t
xtilda
=
decltype
(
xtilda_
){};
constexpr
index_t
Ydotnonzero
=
(
ytilda
+
1
)
*
Ydot
<=
Y
?
Ydot
:
Y
%
Ydot
;
constexpr
index_t
Xdotnonzero
=
(
xtilda
+
1
)
*
Xdot
<=
X
?
Xdot
:
X
%
Xdot
;
// A matrix
constexpr
auto
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_1_htildatrim_1_wtildatrim_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_1_htildatrim_1_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
);
index_t
shared_mem_size
=
0
;
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda
)
{
auto
gemm
=
f_get_gemm
(
ytilda
,
xtilda
);
shared_mem_size
=
math
::
max
(
shared_mem_size
,
gemm
.
GetSharedMemorySize
());
});
});
__shared__
Float
p_shared_float
[
shared_mem_size
/
sizeof
(
Float
)];
// GEMMs
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda
)
{
auto
gemm
=
f_get_gemm
(
ytilda
,
xtilda
);
gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
,
p_shared_float
);
});
});
}
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
View file @
9750de73
...
...
@@ -50,9 +50,37 @@ template <index_t GridSize,
index_t
CThreadCopyDstDataPerWrite
>
struct
GridwiseGemmTransposedANormalBNormalC_v1
{
__host__
__device__
static
constexpr
index_t
GetSharedMemorySize
()
{
constexpr
index_t
max_lds_align
=
math
::
lcm
(
ABlockCopyDstDataPerWrite_M
,
BBlockCopyDstDataPerWrite_N
,
ThreadGemmDataPerReadM
,
ThreadGemmDataPerReadN
);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_k_m_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
KPerBlock
,
MPerBlock
>
{},
Number
<
max_lds_align
>
{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_k_n_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
KPerBlock
,
NPerBlock
>
{},
Number
<
max_lds_align
>
{});
// LDS allocation for A and B: be careful of alignment
constexpr
index_t
a_block_space
=
math
::
integer_least_multiple
(
a_k_m_block_desc
.
GetElementSpace
(),
max_lds_align
);
constexpr
index_t
b_block_space
=
math
::
integer_least_multiple
(
b_k_n_block_desc
.
GetElementSpace
(),
max_lds_align
);
return
2
*
(
a_block_space
+
b_block_space
)
*
sizeof
(
Float
);
}
__device__
void
Run
(
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_b_global
,
Float
*
__restrict__
p_c_global
)
const
Float
*
__restrict__
p_c_global
,
void
*
__restrict__
p_shared
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
...
...
@@ -184,8 +212,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr
index_t
b_block_space
=
math
::
integer_least_multiple
(
b_k_n_block_desc
.
GetElementSpace
(),
max_lds_align
);
__shared__
Float
p_a_block_double
[
2
*
a_block_space
]
;
__shared__
Float
p_b_block_double
[
2
*
b
_block_space
]
;
Float
*
p_a_block_double
=
reinterpret_cast
<
Float
*>
(
p_shared
)
;
Float
*
p_b_block_double
=
p_a_block_double
+
2
*
a
_block_space
;
// register allocation for output
AccFloat
p_c_thread
[
c_m0m1_n0n1_thread_mtx_desc
.
GetElementSpace
()];
...
...
@@ -329,6 +357,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
.
Run
(
p_c_thread
,
p_c_global
);
}
}
__device__
void
Run
(
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_b_global
,
Float
*
__restrict__
p_c_global
)
const
{
constexpr
index_t
shared_mem_size
=
GetSharedMemorySize
();
__shared__
Float
p_shared_float
[
shared_mem_size
/
sizeof
(
Float
)];
Run
(
p_a_global
,
p_b_global
,
p_c_global
,
p_shared_float
);
}
};
}
// namespace ck
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
9750de73
...
...
@@ -187,7 +187,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
...
...
@@ -249,7 +249,7 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif
0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif
1
#elif
0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
...
...
driver/src/conv_driver.cpp
View file @
9750de73
...
...
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr
index_t
N
=
128
;
...
...
@@ -327,7 +327,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
0
#elif
1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment