Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8767acb2
Commit
8767acb2
authored
Dec 22, 2021
by
Chao Liu
Browse files
clean up
parent
35978330
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
221 additions
and
192 deletions
+221
-192
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp
+165
-141
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp
+2
-0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp
+2
-0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp
+2
-0
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
...d_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
+17
-17
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
..._conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
+4
-5
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+24
-24
example/5_conv2d_fwd_xdl_c_shuffle_bias_relu/conv2d_fwd_xdl_c_shuffle_bias_relu.cpp
..._shuffle_bias_relu/conv2d_fwd_xdl_c_shuffle_bias_relu.cpp
+5
-5
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp
View file @
8767acb2
...
...
@@ -17,8 +17,8 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -35,10 +35,10 @@ __global__ void
const
FloatC
*
__restrict__
p_c0_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
...
...
@@ -54,8 +54,8 @@ __global__ void
p_shared
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
a_element_op
,
b_element_op
,
c_element_op
,
...
...
@@ -81,8 +81,8 @@ template <
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
M
Repeat
,
index_t
N
Repeat
,
index_t
M
XdlPerWave
,
index_t
N
XdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -99,9 +99,9 @@ template <
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
...
...
@@ -117,8 +117,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
...
...
@@ -137,6 +136,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
}();
return
a_block_desc_k0_m_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
...
...
@@ -152,14 +158,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
}();
return
b_block_desc_k0_n_k1
;
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
return
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
;
}
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
constexpr
auto
b_block_space_size
_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
constexpr
auto
c_block_size
=
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatC
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
...
@@ -173,8 +220,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
Repeat
)
==
0
)
&&
(
NPerBlock
%
(
N
Repeat
*
NPerXdl
))
==
0
,
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
XdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
N
XdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
...
@@ -223,7 +270,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
template
<
typename
CGridDesc_M_N_
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
const
CGridDesc_M_N_
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
...
...
@@ -232,20 +279,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
const
auto
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
=
const
auto
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
M
Repeat
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
N
Repeat
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
M
XdlPerWave
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
N
XdlPerWave
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
;
return
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -283,14 +330,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
using
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
C0GridDesc_M_N
{}))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
...
...
@@ -304,10 +351,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
...
@@ -319,11 +366,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c0_grid
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
...
@@ -343,34 +390,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
...
...
@@ -448,21 +471,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
decltype
(
b_block_desc_k0_n_k1
),
MPerXdl
,
NPerXdl
,
M
Repeat
,
N
Repeat
,
M
XdlPerWave
,
N
XdlPerWave
,
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
,
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
...
...
@@ -516,12 +539,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
// shuffle C and write out
{
static_assert
(
M
Repeat
%
CShuffleM
Repeat
PerShuffle
==
0
&&
N
Repeat
%
CShuffleN
Repeat
PerShuffle
==
0
,
static_assert
(
M
XdlPerWave
%
CShuffleM
XdlPerWave
PerShuffle
==
0
&&
N
XdlPerWave
%
CShuffleN
XdlPerWave
PerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
...
...
@@ -541,31 +564,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared
),
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
CShuffleMRepeatPerShuffle
>
{}),
// M0 (MRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
CShuffleNRepeatPerShuffle
>
{}),
// N0 (NRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{}),
// M0 (MXdlPerWave) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{}),
// N0 (NXdlPerWave) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -616,8 +635,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleM
Repeat
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
Sequence
<
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
I1
,
I1
,
M2
,
...
...
@@ -646,53 +665,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleM
Repeat
PerShuffle
,
CShuffleM
XdlPerWave
PerShuffle
,
MWave
*
MPerXdl
,
1
,
CShuffleN
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename Src0Data,
FloatC
,
// typename Src1Data,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
{
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c_element_op
};
constexpr
auto
m
repeat
_forward_step
=
make_multi_index
(
0
,
CShuffleM
Repeat
PerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
n
repeat
_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
Repeat
PerShuffle
,
0
);
constexpr
auto
n
repeat
_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
Repeat
PerShuffle
,
0
);
constexpr
auto
m
xdlperwave
_forward_step
=
make_multi_index
(
0
,
CShuffleM
XdlPerWave
PerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
n
xdlperwave
_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
XdlPerWave
PerShuffle
,
0
);
constexpr
auto
n
xdlperwave
_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
XdlPerWave
PerShuffle
,
0
);
static_for
<
0
,
M
Repeat
,
CShuffleM
Repeat
PerShuffle
>
{}([
&
](
auto
m
repeat
_iter
)
{
constexpr
auto
m
repeat
=
mrepeat
_iter
;
static_for
<
0
,
M
XdlPerWave
,
CShuffleM
XdlPerWave
PerShuffle
>
{}([
&
](
auto
m
xdlperwave
_iter
)
{
constexpr
auto
m
xdlperwave
=
mxdlperwave
_iter
;
static_for
<
0
,
NRepeat
,
CShuffleNRepeatPerShuffle
>
{}([
&
](
auto
nrepeat_iter
)
{
constexpr
bool
nrepeat_forward_sweep
=
(
mrepeat
%
(
2
*
CShuffleMRepeatPerShuffle
)
==
0
);
static_for
<
0
,
NXdlPerWave
,
CShuffleNXdlPerWavePerShuffle
>
{}([
&
](
auto
nxdlperwave_iter
)
{
constexpr
bool
nxdlperwave_forward_sweep
=
(
mxdlperwave
%
(
2
*
CShuffleMXdlPerWavePerShuffle
)
==
0
);
constexpr
index_t
n
repeat
_value
=
n
repeat
_forward_sweep
?
n
repeat
_iter
:
(
N
Repeat
-
nrepeat
_iter
-
CShuffleN
Repeat
PerShuffle
);
constexpr
index_t
n
xdlperwave
_value
=
n
xdlperwave
_forward_sweep
?
n
xdlperwave
_iter
:
(
N
XdlPerWave
-
nxdlperwave
_iter
-
CShuffleN
XdlPerWave
PerShuffle
);
constexpr
auto
n
repeat
=
Number
<
n
repeat
_value
>
{};
constexpr
auto
n
xdlperwave
=
Number
<
n
xdlperwave
_value
>
{};
// make sure it's safe to do ds_write
block_sync_lds
();
...
...
@@ -700,7 +724,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
// VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
m
repeat
,
nrepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
m
xdlperwave
,
nxdlperwave
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_buf
);
...
...
@@ -710,47 +734,47 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
// LDS to global
c_block_copy_lds_to_global
.
Run
(
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_block_buf
,
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c0_grid_buf
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_grid_buf
);
// move on n
repeat
dimension
if
constexpr
(
n
repeat
_forward_sweep
&&
(
n
repeat
<
NRepeat
-
CShuffleN
Repeat
PerShuffle
))
// move on n
xdlperwave
dimension
if
constexpr
(
n
xdlperwave
_forward_sweep
&&
(
n
xdlperwave
<
NXdlPerWave
-
CShuffleN
XdlPerWave
PerShuffle
))
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_forward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_forward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_forward_step
);
}
else
if
constexpr
((
!
n
repeat
_forward_sweep
)
&&
(
n
repeat
>
0
))
else
if
constexpr
((
!
n
xdlperwave
_forward_sweep
)
&&
(
n
xdlperwave
>
0
))
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_backward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_backward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
n
repeat
_backward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
xdlperwave
_backward_step
);
}
});
// move on m
repeat
dimension
if
constexpr
(
m
repeat
<
MRepeat
-
CShuffleM
Repeat
PerShuffle
)
// move on m
xdlperwave
dimension
if
constexpr
(
m
xdlperwave
<
MXdlPerWave
-
CShuffleM
XdlPerWave
PerShuffle
)
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
m
repeat
_forward_step
);
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
xdlperwave
_forward_step
);
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
m
repeat
_forward_step
);
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
xdlperwave
_forward_step
);
}
});
}
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp
View file @
8767acb2
...
...
@@ -54,6 +54,8 @@ struct ThreadwiseTensorSliceTransfer_v6r1
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp
View file @
8767acb2
...
...
@@ -62,6 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v6r2
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
__device__
void
SetSrc0SliceOrigin
(
const
Src0Desc
&
src0_desc
,
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp
View file @
8767acb2
...
...
@@ -70,6 +70,8 @@ struct ThreadwiseTensorSliceTransfer_v6r3
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
__device__
void
SetSrc0SliceOrigin
(
const
Src0Desc
&
src0_desc
,
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
8767acb2
...
...
@@ -21,23 +21,23 @@ static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set;
using
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
//
##########################################################################################
| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//
##########################################################################################
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
//
##########################################################################################
| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
//
##########################################################################################
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
//
| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
//
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
//
| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
//
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
MemorySet
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
>
;
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
8767acb2
...
...
@@ -17,9 +17,8 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough_v2
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
...
...
@@ -37,8 +36,8 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances =
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
>
;
// clang-format on
>
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough_v2
>>&
device_conv_instances
)
...
...
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
8767acb2
...
...
@@ -50,9 +50,9 @@ template <
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwdBiasActivation
<
InElementwiseOperation
,
...
...
@@ -262,9 +262,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleM
Repeat
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
...
...
@@ -297,8 +297,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -326,14 +326,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c_grid_desc_m_n_
);
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c0_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
...
...
@@ -350,11 +350,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -414,10 +414,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -436,8 +436,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
arg
.
p_c0_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
...
...
@@ -453,10 +453,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
remove_reference_t
<
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
C0GridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -475,8 +475,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
arg
.
p_c0_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
c0_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
...
...
example/5_conv2d_fwd_xdl_c_shuffle_bias_relu/conv2d_fwd_xdl_c_shuffle_bias_relu.cpp
View file @
8767acb2
...
...
@@ -36,11 +36,11 @@ static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set;
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| | |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
MemorySet
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
MemorySet
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// clang-format on
template
<
typename
TIn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment