Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f5e64f10
Commit
f5e64f10
authored
Dec 22, 2021
by
Chao Liu
Browse files
clean up
parent
8767acb2
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
274 additions
and
261 deletions
+274
-261
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
+184
-159
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
...l_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
+17
-17
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
...ffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
+17
-17
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
...d_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
+4
-4
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
..._conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
+9
-8
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+33
-33
example/6_conv2d_fwd_xdl_c_shuffle_bias_relu_add/conv2d_fwd_xdl_c_shuffle_bias_relu_add.cpp
..._bias_relu_add/conv2d_fwd_xdl_c_shuffle_bias_relu_add.cpp
+5
-5
example/7_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add/conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add.cpp
...mic_add/conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add.cpp
+5
-18
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
View file @
f5e64f10
...
...
@@ -17,9 +17,9 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -37,12 +37,12 @@ __global__ void
const
FloatC
*
__restrict__
p_c1_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
...
...
@@ -59,9 +59,9 @@ __global__ void
p_shared
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
a_element_op
,
b_element_op
,
c_element_op
,
...
...
@@ -88,8 +88,8 @@ template <
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
M
Repeat
,
index_t
N
Repeat
,
index_t
M
XdlPerWave
,
index_t
N
XdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -106,9 +106,9 @@ template <
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
...
...
@@ -124,8 +124,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
...
...
@@ -144,6 +143,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
}();
return
a_block_desc_k0_m_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
...
...
@@ -159,14 +165,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
}();
return
b_block_desc_k0_n_k1
;
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
return
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
;
}
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
constexpr
auto
b_block_space_size
_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
constexpr
auto
c_block_size
=
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatC
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
...
@@ -180,8 +227,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
Repeat
)
==
0
)
&&
(
NPerBlock
%
(
N
Repeat
*
NPerXdl
))
==
0
,
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
XdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
N
XdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
...
@@ -230,7 +277,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
template
<
typename
CGridDesc_M_N_
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
const
CGridDesc_M_N_
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
...
...
@@ -239,20 +286,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
const
auto
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
=
const
auto
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
M
Repeat
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
N
Repeat
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
M
XdlPerWave
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
N
XdlPerWave
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
;
return
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -290,19 +337,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
using
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
C0GridDesc_M_N
{}))
>
;
using
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
C1GridDesc_M_N
{}))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
...
...
@@ -317,12 +364,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
...
@@ -334,15 +381,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c0_grid
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c1_grid
,
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
...
@@ -362,34 +409,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
...
...
@@ -467,21 +490,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
decltype
(
b_block_desc_k0_n_k1
),
MPerXdl
,
NPerXdl
,
M
Repeat
,
N
Repeat
,
M
XdlPerWave
,
N
XdlPerWave
,
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
,
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
...
...
@@ -535,12 +558,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// shuffle C and write out
{
static_assert
(
M
Repeat
%
CShuffleM
Repeat
PerShuffle
==
0
&&
N
Repeat
%
CShuffleN
Repeat
PerShuffle
==
0
,
static_assert
(
M
XdlPerWave
%
CShuffleM
XdlPerWave
PerShuffle
==
0
&&
N
XdlPerWave
%
CShuffleN
XdlPerWave
PerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
...
...
@@ -560,31 +583,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared
),
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
CShuffleMRepeatPerShuffle
>
{}),
// M0 (MRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
CShuffleNRepeatPerShuffle
>
{}),
// N0 (NRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{}),
// M0 (MXdlPerWave) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{}),
// N0 (NXdlPerWave) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -635,8 +654,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleM
Repeat
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
Sequence
<
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
I1
,
I1
,
M2
,
...
...
@@ -665,21 +684,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleM
Repeat
PerShuffle
,
CShuffleM
XdlPerWave
PerShuffle
,
MWave
*
MPerXdl
,
1
,
CShuffleN
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename Src0Data,
FloatC
,
// typename Src1Data,
FloatC
,
// typename Src2Data,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
...
...
@@ -687,36 +710,38 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc2ResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
{
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c_element_op
};
constexpr
auto
m
repeat
_forward_step
=
make_multi_index
(
0
,
CShuffleM
Repeat
PerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
n
repeat
_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
Repeat
PerShuffle
,
0
);
constexpr
auto
n
repeat
_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
Repeat
PerShuffle
,
0
);
constexpr
auto
m
xdlperwave
_forward_step
=
make_multi_index
(
0
,
CShuffleM
XdlPerWave
PerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
n
xdlperwave
_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
XdlPerWave
PerShuffle
,
0
);
constexpr
auto
n
xdlperwave
_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
XdlPerWave
PerShuffle
,
0
);
static_for
<
0
,
M
Repeat
,
CShuffleM
Repeat
PerShuffle
>
{}([
&
](
auto
m
repeat
_iter
)
{
constexpr
auto
m
repeat
=
mrepeat
_iter
;
static_for
<
0
,
M
XdlPerWave
,
CShuffleM
XdlPerWave
PerShuffle
>
{}([
&
](
auto
m
xdlperwave
_iter
)
{
constexpr
auto
m
xdlperwave
=
mxdlperwave
_iter
;
static_for
<
0
,
NRepeat
,
CShuffleNRepeatPerShuffle
>
{}([
&
](
auto
nrepeat_iter
)
{
constexpr
bool
nrepeat_forward_sweep
=
(
mrepeat
%
(
2
*
CShuffleMRepeatPerShuffle
)
==
0
);
static_for
<
0
,
NXdlPerWave
,
CShuffleNXdlPerWavePerShuffle
>
{}([
&
](
auto
nxdlperwave_iter
)
{
constexpr
bool
nxdlperwave_forward_sweep
=
(
mxdlperwave
%
(
2
*
CShuffleMXdlPerWavePerShuffle
)
==
0
);
constexpr
index_t
n
repeat
_value
=
n
repeat
_forward_sweep
?
n
repeat
_iter
:
(
N
Repeat
-
nrepeat
_iter
-
CShuffleN
Repeat
PerShuffle
);
constexpr
index_t
n
xdlperwave
_value
=
n
xdlperwave
_forward_sweep
?
n
xdlperwave
_iter
:
(
N
XdlPerWave
-
nxdlperwave
_iter
-
CShuffleN
XdlPerWave
PerShuffle
);
constexpr
auto
n
repeat
=
Number
<
n
repeat
_value
>
{};
constexpr
auto
n
xdlperwave
=
Number
<
n
xdlperwave
_value
>
{};
// make sure it's safe to do ds_write
block_sync_lds
();
...
...
@@ -724,7 +749,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
m
repeat
,
nrepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
m
xdlperwave
,
nxdlperwave
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_buf
);
...
...
@@ -734,61 +759,61 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// LDS to global
c_block_copy_lds_to_global
.
Run
(
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_block_buf
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c0_grid_buf
,
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c1_grid_buf
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_grid_buf
);
// move on n
repeat
dimension
if
constexpr
(
n
repeat
_forward_sweep
&&
(
n
repeat
<
NRepeat
-
CShuffleN
Repeat
PerShuffle
))
// move on n
xdlperwave
dimension
if
constexpr
(
n
xdlperwave
_forward_sweep
&&
(
n
xdlperwave
<
NXdlPerWave
-
CShuffleN
XdlPerWave
PerShuffle
))
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_forward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveSrc2SliceWindow
(
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_forward_step
);
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_forward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_forward_step
);
}
else
if
constexpr
((
!
n
repeat
_forward_sweep
)
&&
(
n
repeat
>
0
))
else
if
constexpr
((
!
n
xdlperwave
_forward_sweep
)
&&
(
n
xdlperwave
>
0
))
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_backward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_backward_step
);
c_block_copy_lds_to_global
.
MoveSrc2SliceWindow
(
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_backward_step
);
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_backward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_backward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_backward_step
);
}
});
// move on m
repeat
dimension
if
constexpr
(
m
repeat
<
MRepeat
-
CShuffleM
Repeat
PerShuffle
)
// move on m
xdlperwave
dimension
if
constexpr
(
m
xdlperwave
<
MXdlPerWave
-
CShuffleM
XdlPerWave
PerShuffle
)
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
m
repeat
_forward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveSrc2SliceWindow
(
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
m
repeat
_forward_step
);
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
m
repeat
_forward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
xdlperwave
_forward_step
);
}
});
}
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
f5e64f10
...
...
@@ -19,23 +19,23 @@ using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
//##############################################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddReluAdd
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
>
;
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
f5e64f10
...
...
@@ -21,23 +21,23 @@ static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::Atomi
using
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
32
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
InMemoryAtomicAdd
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
4
,
1
,
1
,
16
>
,
2
>
// clang-format on
>
;
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
f5e64f10
...
...
@@ -21,10 +21,10 @@ static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set;
using
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
//
| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//
| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//
##########################################################################################
| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//
##########################################################################################
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//
##########################################################################################
| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//
##########################################################################################
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
f5e64f10
...
...
@@ -17,12 +17,13 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough_v2
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
...
...
@@ -36,8 +37,8 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple<
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
>
;
// clang-format on
>
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough_v2
>>&
device_conv_instances
)
...
...
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
f5e64f10
...
...
@@ -49,9 +49,9 @@ template <
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
...
@@ -269,9 +269,9 @@ struct
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleM
Repeat
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
...
...
@@ -307,9 +307,9 @@ struct
c_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c1_grid_desc_m_n_
{},
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -338,19 +338,19 @@ struct
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c_grid_desc_m_n_
);
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c0_grid_desc_m_n_
);
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c1_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
...
...
@@ -369,14 +369,14 @@ struct
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -439,13 +439,13 @@ struct
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -465,9 +465,9 @@ struct
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
...
...
@@ -483,13 +483,13 @@ struct
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C1GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -509,9 +509,9 @@ struct
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c1_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c1_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
...
...
example/6_conv2d_fwd_xdl_c_shuffle_bias_relu_add/conv2d_fwd_xdl_c_shuffle_bias_relu_add.cpp
View file @
f5e64f10
...
...
@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
// clang-format off
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| | |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// clang-format on
template
<
typename
TIn
,
...
...
example/7_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add/conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add.cpp
View file @
f5e64f10
...
...
@@ -36,11 +36,11 @@ static constexpr auto MemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicA
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| | |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
MemoryAtomicAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
;
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
MemoryAtomicAdd
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
8
,
1
,
1
,
32
>
,
2
>
;
// clang-format on
template
<
typename
TIn
,
...
...
@@ -209,12 +209,6 @@ int main(int argc, char* argv[])
{
case
0
:
break
;
case
1
:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{});
out_n_k_ho_wo_host_result
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{});
break
;
case
2
:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
out_n_k_ho_wo_host_result
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
...
...
@@ -298,12 +292,5 @@ int main(int argc, char* argv[])
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
check_error
(
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_device_result
);
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in_n_c_hi_wi
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei: "
,
wei_k_c_y_x
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_host : "
,
out_n_k_ho_wo_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_device: "
,
out_n_k_ho_wo_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment