Commit c72a0d3e authored by Alan Turner's avatar Alan Turner
Browse files

Add Descriptor class and Run method to device_gemm_multiple_d_xdl_cshuffle.hpp...

Add Descriptor class and Run method to device_gemm_multiple_d_xdl_cshuffle.hpp and make BlockToCTileMap_M00_N0_M01Adapt constexpr
parent fa998675
...@@ -698,6 +698,120 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -698,6 +698,120 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str(); return str.str();
} }
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
{
static constexpr auto pad_ds_tuple()
{
return transform_tuples(
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
DsDesc{});
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
pad_ds_tuple()))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
Block2ETileMap block_2_etile_map;
bool has_main_k_block_loop = false;
bool is_valid = false;
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e)
: a_grid_desc_ak0_m_ak1{GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(a))},
b_grid_desc_bk0_n_bk1{GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(b))},
ds_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds))},
e_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
is_valid{GridwiseGemm::CheckValidity(
(DeviceOp::matrix_padder.PadADescriptor_M_K(a)),
DeviceOp::matrix_padder.PadBDescriptor_N_K(b),
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds),
DeviceOp::matrix_padder.PadCDescriptor_M_N(e),
block_2_etile_map)}
{
}
};
template <class ADesc, class BDesc, class DsDesc, class EDesc>
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e)
{
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(a, b, ds, e);
}
template <class Desc, class DsPointer>
__device__ static void Run(Desc desc,
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
assert(desc.is_valid);
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
AElementwiseOperation{},
BElementwiseOperation{},
CDEElementwiseOperation{},
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
else
{
GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
AElementwiseOperation{},
BElementwiseOperation{},
CDEElementwiseOperation{},
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
}
}; };
} // namespace device } // namespace device
......
...@@ -117,15 +117,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -117,15 +117,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8) index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -203,13 +203,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -203,13 +203,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const const CTileDim& /* c_tile_dim */) const
{ {
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private: private:
index_t M01_; index_t M01_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment