Commit 34520806 authored by Paul's avatar Paul
Browse files

Format

parent 2ca29096
......@@ -43,8 +43,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), decltype(b_grid_desc_bk0_n_bk1), decltype(c_grid_desc_m_n)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_m_n)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
......
......@@ -149,8 +149,7 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()
>
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CKDeviceGemm
{
static constexpr auto I0 = ck::Number<0>{};
......@@ -158,9 +157,8 @@ struct CKDeviceGemm
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
template<class Descriptor>
static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
template <class Descriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
{
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
......@@ -253,9 +251,8 @@ struct CKDeviceGemm
}
}
template<class Descriptor>
static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
template <class Descriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
{
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
......@@ -348,9 +345,8 @@ struct CKDeviceGemm
}
}
template<class Descriptor>
static constexpr auto
MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
template <class Descriptor>
static constexpr auto MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
{
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
......@@ -408,7 +404,7 @@ struct CKDeviceGemm
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping
template<class CGridDesc_M_N>
template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
......@@ -416,7 +412,7 @@ struct CKDeviceGemm
c_grid_desc_m_n);
}
template<class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment