Commit f00dab9f authored by aska-0096's avatar aska-0096
Browse files

conv A-skip lds ported

parent a38ce024
......@@ -82,13 +82,13 @@ using DeviceOpInstance =
GemmSpec,
256,
128,
256,
8,
128,
32,
8,
16,
16,
4,
4,
1,
8,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -105,7 +105,7 @@ using DeviceOpInstance =
true,
1,
1,
S<1, 32, 1, 8>,
S<1, 128, 1, 2>,
8>;
int main(int argc, char* argv[])
......
......@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
4, // K0PerBlock
128, // NPerBlock
32, // KPerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
4, // NRepeat
1, // MRepeat
8, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
true, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
S<1, 128, 1, 2>,
8>;
template <ck::index_t NDimSpatial>
......
......@@ -16,6 +16,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace tensor_operation {
......@@ -38,10 +40,10 @@ template <typename ALayout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
......@@ -83,19 +85,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
const index_t K0 = K / K1;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
......@@ -105,25 +123,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
#endif
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const index_t K0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
......@@ -216,7 +244,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
// Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
......@@ -231,7 +259,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
AGridDesc,
BGridDesc_K0_N_K1,
DsGridDesc_M_N,
EGridDesc_M_N,
......@@ -243,9 +271,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
......@@ -258,6 +286,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -266,6 +295,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -298,7 +328,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_k0_m_k1_{},
a_grid_desc{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
......@@ -311,7 +341,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
......@@ -328,7 +358,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
if(GridwiseOp::CheckValidity(a_grid_desc,
b_grid_desc_k0_n_k1_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
......@@ -351,7 +381,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType* p_e_grid_;
// Tensor Descriptors
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
AGridDesc a_grid_desc;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
......@@ -382,9 +412,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.a_grid_desc{" << arg.a_grid_desc.GetLength(I0)
<< ", " << arg.a_grid_desc.GetLength(I1) << ", "
<< arg.a_grid_desc.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
......@@ -396,7 +426,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
#endif
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......@@ -410,7 +440,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
float ave_time = 0;
......@@ -422,7 +452,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -444,7 +474,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -461,7 +491,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -483,7 +513,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -524,7 +554,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
return false;
}
return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
return GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......@@ -630,10 +660,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
......
......@@ -112,10 +112,10 @@ template <index_t NDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
......@@ -161,7 +161,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
......@@ -171,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
......@@ -197,12 +211,41 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{
const index_t K0 = K / K1;
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
......@@ -212,7 +255,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
const auto BK1 = K1;
const auto BK0 = K / BK1;
return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename ELay>
......@@ -245,50 +298,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK1 = K1;
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK1 = K1;
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
......@@ -300,7 +314,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1,
AGridDesc,
BGridDesc_BK0_N_BK1,
DsGridDesc_M_N,
EGridDesc_M_N,
......@@ -312,9 +326,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
......@@ -327,6 +341,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
AEnableLds,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -335,6 +350,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BEnableLds,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -375,7 +391,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -385,13 +404,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
......@@ -443,8 +457,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
std::cout << "A[M, K]: " << a_grid_desc << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
......@@ -459,13 +473,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
AGridDesc a_grid_desc;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
......@@ -514,7 +526,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
......@@ -528,7 +540,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -549,7 +561,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.a_grid_desc,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -719,7 +731,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
// check Gridwise GEMM
return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
return GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment