Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
58b6996a
Commit
58b6996a
authored
Sep 16, 2021
by
ltqin
Browse files
some gridwise gemm write to C matrix
parent
971220d8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
289 additions
and
263 deletions
+289
-263
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+25
-0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
+246
-245
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
...forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
+18
-18
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
58b6996a
...
@@ -131,6 +131,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -131,6 +131,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
}
__host__
__device__
static
constexpr
auto
GetCGM0N0M1N1M2M3M4N2ThreadDescriptor
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
{
{
constexpr
auto
c_m0_n0_m1_n1_m2_n2_block_desc
=
constexpr
auto
c_m0_n0_m1_n1_m2_n2_block_desc
=
...
@@ -144,6 +155,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -144,6 +155,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_block_desc
);
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_block_desc
);
}
}
__host__
__device__
static
constexpr
auto
GetCGM0N0M1N1M2M3M4N2BlockDescriptor
()
{
constexpr
auto
c_m0_n0_m1_n1_m2_n2_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_block_desc
);
}
template
<
typename
CMNGridDesc
>
template
<
typename
CMNGridDesc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
View file @
58b6996a
...
@@ -142,6 +142,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
...
@@ -142,6 +142,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
...
@@ -388,9 +389,9 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
...
@@ -388,9 +389,9 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
constexpr
auto
c_mr_nr_blk_desc
=
constexpr
auto
c_mr_nr_blk_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
constexpr
auto
c_
g_
m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
();
blockwise_gemm
.
GetC
G
M0N0M1N1M2M3M4N2ThreadDescriptor
();
constexpr
auto
CBlkSize
=
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
.
GetElementSpaceSize
();
constexpr
auto
CBlkSize
=
c_
g_
m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
.
GetElementSpaceSize
();
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
vector_type
<
FloatAcc
,
CBlkSize
>
,
vector_type
<
FloatAcc
,
CBlkSize
>
,
...
@@ -472,248 +473,248 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
...
@@ -472,248 +473,248 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
/*
// output: register to global memory
// output: register to global memory
{
{
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
constexpr
auto
c_
g_
m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
blockwise_gemm
.
GetC
G
M0N0M1N1M2M3M4N2BlockDescriptor
();
constexpr auto M2 =
constexpr
auto
M2
=
c_g_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
c
_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I
4
);
constexpr auto M3 =
constexpr
auto
M3
=
c_g
_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I
6
);
c
_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I
5
);
constexpr auto M4 =
constexpr
auto
M4
=
c_g
_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I
7
);
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
const auto c_thread_mtx_on_block =
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0,
I0);
const
index_t
m_thread_data_on_grid
=
m_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I0
];
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]
;
constexpr
auto
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
=
CGridStepHacks
{}
;
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
auto
c_thread_copy
=
= CGridStepHacks{};
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
auto c
_thread_
copy =
decltype
(
c_g_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc
),
ThreadwiseTensorSliceTransfer_v1r3<FloatC
,
decltype
(
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
)
,
FloatC
,
Sequence
<
I1
,
I1
,
I1
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc)
,
CThreadTransferSrcDstAccessOrder
,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)
,
CThreadTransferSrcDstVectorDim
,
Sequence<I1, I1, I1, I1
,
CThreadTransferDstScalarPerVector
,
M2, I1, M4, I1>, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim
,
CGlobalMemoryDataOperation
,
CThreadTransferDstScalarPerVector
,
1
,
CGlobalMemoryDataOperation,
true
>
{
1
,
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
true>{
make_multi_index
(
g_idx
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
0
,
make_multi_index(
0,
0
,
0,
0
,
0,
0
,
0
,
m_thread_data_on_grid
/
(
M3
*
M4
)
,
m_thread_data_on_grid
/
(M3 * M4),
m_thread_data_on_grid
%
(
M3
*
M4
)
/
M4
,
m_thread_data_on_grid %
(M3 * M4) /
m_thread_data_on_grid
%
M4
,
M4, m_thread_data_on_grid % M4,
n_thread_data_on_grid)};
n_thread_data_on_grid
)};
auto init_copy = [&](auto c_thread_idx_) {
auto
init_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr auto blk_off =
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_mr_nr_blk_desc.CalculateOffset(c
_thread_
idx_);
c_thread_copy
.
Run
(
c_g_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc
,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
)
,
make_tuple(I0, I0, I0, I0, I0, I0, I0
,
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>()
,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_grid_buf
,
c
_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_g
_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
return c_thread_idx_;
return
c_thread_idx_
;
};
};
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
/*
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0,
constexpr auto mrepeat_step_plus = make_multi_index(1, 0,
0, 0, 0, 0, 0, 0);
0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
mrepeat_step_plus);
constexpr auto blk_off =
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_mr_nr_blk_desc.CalculateOffset(c
_thread_
idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0)
,
make_tuple(I0, I0, I0, I0, I0, I0, I0
,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
};
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1,
constexpr auto nrepeat_step_plus = make_multi_index(0, 1,
0, 0, 0, 0, 0, 0);
0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
nrepeat_step_plus);
constexpr auto blk_off =
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_mr_nr_blk_desc.CalculateOffset(c
_thread_
idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0)
,
make_tuple(I0, I0, I0, I0, I0, I0, I0
,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
};
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0,
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0,
0, 0, 0, 0, 0, 0);
0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
mrepeat_step_plus);
constexpr auto blk_off =
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_mr_nr_blk_desc.CalculateOffset(c
_thread_
idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0)
,
make_tuple(I0, I0, I0, I0, I0, I0, I0
,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
};
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1,
constexpr auto nrepeat_step_minus = make_multi_index(0, -1,
0, 0, 0, 0, 0, 0);
0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
nrepeat_step_minus);
constexpr auto blk_off =
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_mr_nr_blk_desc.CalculateOffset(c
_thread_
idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_
desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0)
,
make_tuple(I0, I0, I0, I0, I0, I0, I0
,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>()
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
};
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4
&& NRepeat == 2) or
&& NRepeat == 2) or
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
(MRepeat == 2
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
&& NRepeat == 1) or
(MRepeat == 1 && NRepeat ==
2) or (MRepeat == 1 &&
(MRepeat == 1 && NRepeat ==
1),
NRepeat == 1),
"wrong");
"wrong");
if constexpr(MRepeat == 4 && NRepeat == 4)
if constexpr(MRepeat == 4 && NRepeat == 4)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
if constexpr(CAccessOrderMRepeatNRepeat)
{
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I3));
nrepeat_plus_copy(make_tuple(I2, I3));
mrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
nrepeat_minus_copy(make_tuple(I3, I0));
}
}
else
else
{
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I3, I2));
mrepeat_plus_copy(make_tuple(I3, I2));
nrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
}
}
else if constexpr(MRepeat == 4 && NRepeat == 2)
else if constexpr(MRepeat == 4 && NRepeat == 2)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
if constexpr(CAccessOrderMRepeatNRepeat)
{
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I1));
mrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
nrepeat_minus_copy(make_tuple(I3, I0));
}
}
else
else
{
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
}
}
else if constexpr(MRepeat == 2 && NRepeat == 4)
else if constexpr(MRepeat == 2 && NRepeat == 4)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
if constexpr(CAccessOrderMRepeatNRepeat)
{
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
nrepeat_minus_copy(make_tuple(I1, I0));
}
}
else
else
{
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
nrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
}
}
else if constexpr(MRepeat == 2 && NRepeat == 2)
else if constexpr(MRepeat == 2 && NRepeat == 2)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
if constexpr(CAccessOrderMRepeatNRepeat)
{
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
nrepeat_minus_copy(make_tuple(I1, I0));
}
}
else
else
{
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
}
}
else if constexpr(MRepeat == 2 && NRepeat == 1)
else if constexpr(MRepeat == 2 && NRepeat == 1)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
}
}
else if constexpr(MRepeat == 1 && NRepeat == 2)
else if constexpr(MRepeat == 1 && NRepeat == 2)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 1 && NRepeat == 1)
else if constexpr(MRepeat == 1 && NRepeat == 1)
{
{
init_copy(make_tuple(I0, I0));
init_copy(make_tuple(I0, I0));
}
}
}
*/
*/
}
}
}
};
// namespace ck
};
// namespace ck
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
View file @
58b6996a
...
@@ -255,24 +255,24 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
...
@@ -255,24 +255,24 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 7+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 7+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 8+: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 8+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 7-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 7-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 8-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 8-: N2
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment