Commit 6a97c046 authored by Alan Turner

Add gpu-invoker to device_gemm_multiple_d_xdl_cshuffle

parent 1b62bfaa
@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
//#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -89,6 +90,70 @@ namespace ck {
namespace tensor_operation {
namespace device {
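// Device-constructible variant of BlockToCTileMap_M00_N0_M01Adapt: maps a flat
// block id to a 2D (m0, n0) output-tile index, visiting the tile grid in
// strips of M01 rows along M.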
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt2() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt2(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
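// For intuition (illustrative values, not from this commit): with M0 = 4 and
// N0 = 4 tiles and M01 = 2, CalculateBottomIndex above walks the C-tile grid
// in 2-row strips:
//   blocks 0..7  -> (0,0)(1,0)(0,1)(1,1)(0,2)(1,2)(0,3)(1,3)
//   blocks 8..15 -> (2,0)(3,0)(2,1)(3,1)(2,2)(3,2)(2,3)(3,3)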
template <typename CTileIdx, typename CTileDim>
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// GEMM:
// input : A[M, K]
// input : B[N, K]
@@ -679,6 +744,108 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str();
}
template<class ADesc,
class BDesc,
class EDesc,
class... DsDesc>
struct Descriptor
{
using AGridDesc_M_K = decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}));
using BGridDesc_N_K = decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}));
using DsGridDesc_M_N = remove_cvref_t<decltype(make_tuple(DeviceOp::matrix_padder.PadCDescriptor_M_N(DsDesc{})...))>;
using EGridDesc_M_N = remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
__device__ constexpr Descriptor(DsDesc... dsdesc)
{
static_assert(GridwiseGemm::CheckValidity(AGridDesc_M_K{},
BGridDesc_N_K{},
DsGridDesc_M_N{},
EGridDesc_M_N{},
get_block_2_etile_map()));
}
constexpr auto get_a_grid_desc_ak0_m_ak1() const
{
return a_grid_desc_ak0_m_ak1;
}
constexpr auto get_b_grid_desc_bk0_n_bk1() const
{
return b_grid_desc_bk0_n_bk1;
}
constexpr auto get_ds_grid_desc_mblock_mperblock_nblock_nperblock() const
{
return ds_grid_desc_mblock_mperblock_nblock_nperblock;
}
constexpr auto get_e_grid_desc_mblock_mperblock_nblock_nperblock() const
{
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
static constexpr auto get_block_2_etile_map()
{
return BlockToCTileMap_M00_N0_M01Adapt2<MPerBlock, NPerBlock, EGridDesc_M_N>(
EGridDesc_M_N{});
}
constexpr bool has_main_k_block_loop() const
{
return GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
}
private:
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
};
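// Note: every member of Descriptor above is determined by the template
// arguments alone, so a constexpr-constructed Descriptor encodes the full
// tiling layout in its type and CheckValidity is evaluated at compile time
// via the static_assert in the constructor.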
template<class ADesc,
class BDesc,
class EDesc,
class... DsDesc>
struct GPU_Invoker
{
using Descriptor = DeviceOp::Descriptor<ADesc, BDesc, EDesc, DsDesc...>;
__device__ constexpr GPU_Invoker(DsDesc... dsdesc) {}
template<class DsPointer>
__device__ static void run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid)
{
// Descriptor state is fully compile-time; build it from default-constructed
// descriptor types (the variadic ctor needs arguments when DsDesc is non-empty).
constexpr Descriptor desc{DsDesc{}...};
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<desc.has_main_k_block_loop()>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
AElementwiseOperation{},
BElementwiseOperation{},
CDEElementwiseOperation{},
desc.get_a_grid_desc_ak0_m_ak1(),
desc.get_b_grid_desc_bk0_n_bk1(),
desc.get_ds_grid_desc_mblock_mperblock_nblock_nperblock(),
desc.get_e_grid_desc_mblock_mperblock_nblock_nperblock(),
desc.get_block_2_etile_map());
}
};
};
} // namespace device
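For reference, a minimal sketch of how the new GPU_Invoker could be driven from a kernel. Everything below is illustrative and not part of this commit: the kernel name is hypothetical, and the pointer parameters are assumed to match a concrete DeviceGemmMultipleD_Xdl_CShuffle instantiation.

// Hypothetical wrapper kernel (not part of this commit). Invoker is expected
// to be DeviceOp::GPU_Invoker<ADesc, BDesc, EDesc, DsDesc...> for a concrete
// DeviceGemmMultipleD_Xdl_CShuffle instantiation, and the pointer types must
// match that instantiation's ADataType/BDataType/EDataType.
template <class Invoker, class ADataType, class BDataType, class DsPointer, class EDataType>
__global__ void gemm_multiple_d_gpu_invoker_kernel(const ADataType* __restrict__ p_a,
                                                   const BDataType* __restrict__ p_b,
                                                   DsPointer p_ds,
                                                   EDataType* __restrict__ p_e)
{
    // All grid descriptors are rebuilt from types inside run(), so no kernel
    // arguments beyond the four pointers are needed.
    Invoker::run(p_a, p_b, p_ds, p_e);
}

A launch would presumably take its grid size from the block-to-C-tile map's CalculateGridSize and its block size from the operation's BlockSize, as the existing host-side invoker does.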
@@ -117,15 +117,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
-__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
+__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
-__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
-__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
@@ -159,13 +159,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt
}
template <typename CTileIdx, typename CTileDim>
-__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
-__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
+__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
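Since the adapter's arithmetic is now constexpr on both host and device, it can be sanity-checked with a plain host program. A self-contained sketch of the same index math (the descriptor machinery is replaced by plain integers; all values are illustrative):

#include <cstdio>

// Standalone copy of the CalculateBottomIndex arithmetic (illustration only).
// M0/N0 = output-tile counts along M/N; M01 = row-grouping factor.
struct Tile { int m0, n0; };

constexpr Tile map_block(int block_1d_id, int M0, int N0, int M01)
{
    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
    const int idx_N0 = block_1d_id % N0;
    const int idx_M0 = block_1d_id / N0;
    // The last (M0 % M01) tile-rows form a shorter strip so that no block id
    // maps outside the grid.
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;
    const int idx_M00 = idx_M0 / M01;
    const int idx_M01 = idx_M0 % M01;
    const int local   = idx_N0 + idx_M01 * N0;
    return {local % M01_adapt + idx_M00 * M01, local / M01_adapt};
}

// Compile-time check mirroring the new constexpr qualifiers: block 9 of a
// 4x4 tile grid with M01 = 2 lands on tile (3, 0).
static_assert(map_block(9, 4, 4, 2).m0 == 3 && map_block(9, 4, 4, 2).n0 == 0, "");

int main()
{
    for(int id = 0; id < 16; ++id)
    {
        const Tile t = map_block(id, 4, 4, 2);
        std::printf("block %2d -> tile (m0=%d, n0=%d)\n", id, t.m0, t.n0);
    }
}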