Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35978330
Commit
35978330
authored
Dec 22, 2021
by
Chao Liu
Browse files
clean up
parent
d3bd5922
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
191 additions
and
167 deletions
+191
-167
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
+146
-123
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
..._conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
+17
-17
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...nclude/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+23
-22
example/4_conv2d_fwd_xdl_c_shuffle/conv2d_fwd_xdl_c_shuffle.cpp
...e/4_conv2d_fwd_xdl_c_shuffle/conv2d_fwd_xdl_c_shuffle.cpp
+5
-5
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
View file @
35978330
...
@@ -17,7 +17,7 @@ template <typename GridwiseGemm,
...
@@ -17,7 +17,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
...
@@ -33,8 +33,8 @@ __global__ void
...
@@ -33,8 +33,8 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
...
@@ -49,7 +49,7 @@ __global__ void
...
@@ -49,7 +49,7 @@ __global__ void
p_shared
,
p_shared
,
a_grid_desc_k0_m_k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
...
@@ -74,8 +74,8 @@ template <
...
@@ -74,8 +74,8 @@ template <
index_t
MPerXdl
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
K1Value
,
index_t
M
Repeat
,
index_t
M
XdlPerWave
,
index_t
N
Repeat
,
index_t
N
XdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
@@ -92,9 +92,9 @@ template <
...
@@ -92,9 +92,9 @@ template <
index_t
BBlockTransferDstScalarPerVector_K1
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
bool
BBlockLdsExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
{
...
@@ -110,8 +110,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -110,8 +110,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -130,6 +129,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -130,6 +129,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}
}();
}();
return
a_block_desc_k0_m_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
if
constexpr
(
BBlockLdsExtraN
)
...
@@ -145,14 +151,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -145,14 +151,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}
}();
}();
return
b_block_desc_k0_n_k1
;
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
return
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
;
}
// TODO: need to calculate LDS usage for C shuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
constexpr
auto
b_block_space_size
_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
constexpr
auto
c_block_size
=
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatC
));
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
@@ -166,8 +213,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -166,8 +213,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
Repeat
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
M
XdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
N
Repeat
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
N
XdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
@@ -215,7 +262,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -215,7 +262,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
...
@@ -224,20 +271,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -224,20 +271,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
const
auto
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
=
const
auto
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
c_grid_desc_m_n
,
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_tuple
(
MBlock
,
Number
<
M
Repeat
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
MBlock
,
Number
<
M
XdlPerWave
>
{},
Number
<
MWave
*
MPerXdl
>
{})),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
make_tuple
(
NBlock
,
Number
<
N
Repeat
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
NBlock
,
Number
<
N
XdlPerWave
>
{},
Number
<
NWave
*
NPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
;
return
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
;
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
...
@@ -275,9 +322,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -275,9 +322,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
CGridDesc_M_N
{}))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
...
@@ -290,8 +337,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -290,8 +337,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
&
const
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
&
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
@@ -303,7 +350,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -303,7 +350,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
p_c_grid
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
.
GetElementSpaceSize
());
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
@@ -323,34 +370,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -323,34 +370,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
...
@@ -428,21 +451,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -428,21 +451,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
MPerXdl
,
MPerXdl
,
NPerXdl
,
NPerXdl
,
M
Repeat
,
M
XdlPerWave
,
N
Repeat
,
N
XdlPerWave
,
K1
>
{};
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
,
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size
_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
...
@@ -496,12 +519,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -496,12 +519,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// shuffle C and write out
// shuffle C and write out
{
{
static_assert
(
M
Repeat
%
CShuffleM
Repeat
PerShuffle
==
0
&&
static_assert
(
M
XdlPerWave
%
CShuffleM
XdlPerWave
PerShuffle
==
0
&&
N
Repeat
%
CShuffleN
Repeat
PerShuffle
==
0
,
N
XdlPerWave
%
CShuffleN
XdlPerWave
PerShuffle
==
0
,
"wrong!"
);
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPerXdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
Repeat
*
NPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
N
XdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
...
@@ -521,31 +544,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -521,31 +544,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
=
constexpr
auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
();
Number
<
CShuffleMRepeatPerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared
),
static_cast
<
FloatC
*>
(
p_shared
),
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
.
GetElementSpaceSize
());
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_tuple
(
make_pass_through_transform
(
make_freeze_transform
(
I0
),
// freeze mblock
Number
<
CShuffleMRepeatPerShuffle
>
{}),
// M0 (MRepeat) per shuffle
make_pass_through_transform
(
make_unmerge_transform
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{}),
// M0 (MXdlPerWave) per shuffle
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_unmerge_transform
(
make_freeze_transform
(
I0
),
// freeze nblock
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_pass_through_transform
(
make_freeze_transform
(
I0
),
// freeze nblock
Number
<
CShuffleNRepeatPerShuffle
>
{}),
// N0 (NRepeat) per shuffle
make_pass_through_transform
(
make_unmerge_transform
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{}),
// N0 (NXdlPerWave) per shuffle
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -596,8 +615,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -596,8 +615,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleM
Repeat
PerShuffle
,
Sequence
<
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
I1
,
I1
,
I1
,
I1
,
M2
,
M2
,
...
@@ -626,48 +645,52 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -626,48 +645,52 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
CElementwiseOperation
,
// ElementwiseOperation,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
Sequence
<
1
,
CShuffleM
Repeat
PerShuffle
,
CShuffleM
XdlPerWave
PerShuffle
,
MWave
*
MPerXdl
,
MWave
*
MPerXdl
,
1
,
1
,
CShuffleN
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename SrcData,
FloatC
,
// typename SrcData,
FloatC
,
// typename DstData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
decltype
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
{
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
0
,
block_work_idx
[
I1
],
0
,
0
),
c_element_op
};
c_element_op
};
constexpr
auto
m
repeat
_forward_step
=
constexpr
auto
m
xdlperwave
_forward_step
=
make_multi_index
(
0
,
CShuffleM
Repeat
PerShuffle
,
0
,
0
,
0
,
0
);
make_multi_index
(
0
,
CShuffleM
XdlPerWave
PerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
n
repeat
_forward_step
=
constexpr
auto
n
xdlperwave
_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
Repeat
PerShuffle
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffleN
XdlPerWave
PerShuffle
,
0
);
constexpr
auto
n
repeat
_backward_step
=
constexpr
auto
n
xdlperwave
_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
Repeat
PerShuffle
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffleN
XdlPerWave
PerShuffle
,
0
);
static_for
<
0
,
M
Repeat
,
CShuffleM
Repeat
PerShuffle
>
{}([
&
](
auto
m
repeat
_iter
)
{
static_for
<
0
,
M
XdlPerWave
,
CShuffleM
XdlPerWave
PerShuffle
>
{}([
&
](
auto
m
xdlperwave
_iter
)
{
constexpr
auto
m
repeat
=
mrepeat
_iter
;
constexpr
auto
m
xdlperwave
=
mxdlperwave
_iter
;
static_for
<
0
,
NRepeat
,
CShuffleNRepeatPerShuffle
>
{}([
&
](
auto
nrepeat_iter
)
{
static_for
<
0
,
constexpr
bool
nrepeat_forward_sweep
=
NXdlPerWave
,
(
mrepeat
%
(
2
*
CShuffleMRepeatPerShuffle
)
==
0
);
CShuffleNXdlPerWavePerShuffle
>
{}([
&
](
auto
nxdlperwave_iter
)
{
constexpr
bool
nxdlperwave_forward_sweep
=
(
mxdlperwave
%
(
2
*
CShuffleMXdlPerWavePerShuffle
)
==
0
);
constexpr
index_t
n
repeat
_value
=
constexpr
index_t
n
xdlperwave
_value
=
n
repeat
_forward_sweep
n
xdlperwave
_forward_sweep
?
n
repeat
_iter
?
n
xdlperwave
_iter
:
(
N
Repeat
-
nrepeat
_iter
-
CShuffleN
Repeat
PerShuffle
);
:
(
N
XdlPerWave
-
nxdlperwave
_iter
-
CShuffleN
XdlPerWave
PerShuffle
);
constexpr
auto
n
repeat
=
Number
<
n
repeat
_value
>
{};
constexpr
auto
n
xdlperwave
=
Number
<
n
xdlperwave
_value
>
{};
// make sure it's safe to do ds_write
// make sure it's safe to do ds_write
block_sync_lds
();
block_sync_lds
();
...
@@ -675,7 +698,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -675,7 +698,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// VGPR to LDS
// VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
m
repeat
,
nrepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
m
xdlperwave
,
nxdlperwave
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_buf
);
c_block_buf
);
...
@@ -685,33 +708,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -685,33 +708,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// LDS to global
// LDS to global
c_block_copy_lds_to_global
.
Run
(
c_block_copy_lds_to_global
.
Run
(
c_block_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_block_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_block_buf
,
c_block_buf
,
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
c_grid_buf
);
c_grid_buf
);
// move on n
repeat
dimension
// move on n
xdlperwave
dimension
if
constexpr
(
n
repeat
_forward_sweep
&&
if
constexpr
(
n
xdlperwave
_forward_sweep
&&
(
n
repeat
<
NRepeat
-
CShuffleN
Repeat
PerShuffle
))
(
n
xdlperwave
<
NXdlPerWave
-
CShuffleN
XdlPerWave
PerShuffle
))
{
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
repeat
_forward_step
);
n
xdlperwave
_forward_step
);
}
}
else
if
constexpr
((
!
n
repeat
_forward_sweep
)
&&
(
n
repeat
>
0
))
else
if
constexpr
((
!
n
xdlperwave
_forward_sweep
)
&&
(
n
xdlperwave
>
0
))
{
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
n
repeat
_backward_step
);
n
xdlperwave
_backward_step
);
}
}
});
});
// move on m
repeat
dimension
// move on m
xdlperwave
dimension
if
constexpr
(
m
repeat
<
MRepeat
-
CShuffleM
Repeat
PerShuffle
)
if
constexpr
(
m
xdlperwave
<
MXdlPerWave
-
CShuffleM
XdlPerWave
PerShuffle
)
{
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl
,
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl
,
m
repeat
_forward_step
);
m
xdlperwave
_forward_step
);
}
}
});
});
}
}
...
...
device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
35978330
...
@@ -20,23 +20,23 @@ using PassThrough_v2 = ck::tensor_operation::element_wise::PassThrough;
...
@@ -20,23 +20,23 @@ using PassThrough_v2 = ck::tensor_operation::element_wise::PassThrough;
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
using
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
//
##########################################################################
| InData| WeiData| OutData| AccData|
A
|
B
|
C
| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//
| InData| WeiData| OutData| AccData|
In
|
Wei
|
Out
| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
//
##########################################################################
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
//
| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
//
##########################################################################
| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
//
| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
//
##########################################################################
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
//
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
...
...
device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
35978330
...
@@ -48,9 +48,9 @@ template <
...
@@ -48,9 +48,9 @@ template <
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
typename
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
...
@@ -249,9 +249,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -249,9 +249,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
BBlockLdsAddExtraN
,
CShuffleM
Repeat
PerShuffle
,
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
CBlockTransferClusterLengths_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
// Argument
...
@@ -281,7 +281,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -281,7 +281,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
a_grid_desc_k0_m_k1_
{},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
{},
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
...
@@ -308,9 +308,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -308,9 +308,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
{
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
=
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
=
GridwiseGemm
::
GridwiseGemm
::
MakeCGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
...
@@ -325,8 +325,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -325,8 +325,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
;
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
;
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -355,23 +355,24 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -355,23 +355,24 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
std
::
cout
<<
"arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{ "
<<
"arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
<<
arg
.
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
"nwavenperxdl_{ "
<<
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.
GetLength
(
I0
)
.
GetLength
(
I0
)
<<
", "
<<
", "
<<
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
<<
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
.
GetLength
(
I1
)
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
<<
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
.
GetLength
(
I2
)
.
GetLength
(
I2
)
<<
", "
<<
", "
<<
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
<<
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
.
GetLength
(
I3
)
.
GetLength
(
I3
)
<<
", "
<<
", "
<<
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
<<
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
.
GetLength
(
I4
)
.
GetLength
(
I4
)
<<
", "
<<
", "
<<
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
<<
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
.
GetLength
(
I5
)
.
GetLength
(
I5
)
<<
"}"
<<
std
::
endl
;
<<
"}"
<<
std
::
endl
;
}
}
...
@@ -404,7 +405,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -404,7 +405,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseGemm
::
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -422,7 +423,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -422,7 +423,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -438,7 +439,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -438,7 +439,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseGemm
::
typename
GridwiseGemm
::
CGridDescriptor_MBlock_M
Repeat
_MWaveMPerXdl_NBlock_N
Repeat
_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_M
XdlPerWave
_MWaveMPerXdl_NBlock_N
XdlPerWave
_NWaveNPerXdl
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -456,7 +457,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -456,7 +457,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_m
repeat
_mwavemperxdl_nblock_n
repeat
_nwavenperxdl_
,
arg
.
c_grid_desc_mblock_m
xdlperwave
_mwavemperxdl_nblock_n
xdlperwave
_nwavenperxdl_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
...
example/4_conv2d_fwd_xdl_c_shuffle/conv2d_fwd_xdl_c_shuffle.cpp
View file @
35978330
...
@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
...
@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle|
CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeat
e| _MBlock_M
Repeat
_MWaveMPerXdl| ScalarPerVector|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWav
e| _MBlock_M
XdlPerWave
_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_N
Repeat
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle|
PerShuffle| _NBlock_N
XdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
| | |
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// clang-format on
// clang-format on
template
<
typename
TIn
,
template
<
typename
TIn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment