Commit fd4ff3a7 authored by aska-0096's avatar aska-0096
Browse files

(4/5) grouped conv pass

parent 12a4ea69
......@@ -52,32 +52,32 @@ using DeviceConvFwdInstance =
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // Prefetch stage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // K1
128, // BlockSize
64, // MPerBlock
64, // NPerBlock
64, // KPerBlock
4, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
1, // NRepeat
S<4, 8, 4>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<4, 8, 4>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
S<1, 32, 1, 8>,
1,
1,
S<1, 16, 1, 8>,
8>;
template <ck::index_t NDimSpatial>
......
......@@ -163,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
......@@ -232,18 +233,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
......@@ -275,18 +279,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
......@@ -556,7 +563,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
......@@ -884,9 +891,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "ABlockTransferSrcScalarPerVector: "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector
<< ">";
<< "BBlockTransferSrcScalarPerVector: "
<< BBlockTransferSrcScalarPerVector;
// clang-format on
return str.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment