Commit fd4ff3a7 authored by aska-0096's avatar aska-0096
Browse files

(4/5) grouped conv pass

parent 12a4ea69
...@@ -52,32 +52,32 @@ using DeviceConvFwdInstance = ...@@ -52,32 +52,32 @@ using DeviceConvFwdInstance =
ConvSpec, // ConvForwardSpecialization ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization GemmSpec, // GemmSpecialization
1, // Prefetch stage 1, // Prefetch stage
256, // BlockSize 128, // BlockSize
128, // MPerBlock 64, // MPerBlock
128, // NPerBlock 64, // NPerBlock
32, // KPerBlock 64, // KPerBlock
8, // K1 4, // K1
16, // MPerWMMA 16, // MPerWMMA
16, // NPerWMMA 16, // NPerWMMA
4, // MRepeat 4, // MRepeat
2, // NRepeat 1, // NRepeat
S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 8, 4>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM true, // ABlockLdsExtraM
S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<4, 8, 4>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
4, 1,
2, 1,
S<1, 32, 1, 8>, S<1, 16, 1, 8>,
8>; 8>;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
......
...@@ -163,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -163,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
...@@ -232,18 +233,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -232,18 +233,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
else else
{ {
constexpr auto A_KRow = WmmaK / K1; constexpr auto A_KRow = 2;
const auto A_KWmma = K / WmmaK; constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock; const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
in_gemmm_gemmk_desc, in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))), make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -275,18 +279,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -275,18 +279,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
else else
{ {
constexpr auto B_KRow = WmmaK / K1; constexpr auto B_KRow = 2;
const auto B_KWmma = K / WmmaK; constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
wei_gemmn_gemmk_desc, wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))), make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -556,7 +563,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -556,7 +563,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
else else
{ {
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5); arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
} }
}(); }();
...@@ -884,9 +891,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -884,9 +891,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", " << K1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "ABlockTransferSrcScalarPerVector: "
<< ABlockTransferSrcScalarPerVector << ", " << ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << "BBlockTransferSrcScalarPerVector: "
<< ">"; << BBlockTransferSrcScalarPerVector;
// clang-format on // clang-format on
return str.str(); return str.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment