Commit 34520806 authored by Paul
Browse files

Format

parent 2ca29096
...@@ -43,8 +43,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -43,8 +43,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>()); constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n); constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), decltype(b_grid_desc_bk0_n_bk1), decltype(c_grid_desc_m_n)>; using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map)); decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_m_n)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
......
...@@ -149,8 +149,7 @@ template <typename ALayout, ...@@ -149,8 +149,7 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock, ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler() ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
>
struct CKDeviceGemm struct CKDeviceGemm
{ {
static constexpr auto I0 = ck::Number<0>{}; static constexpr auto I0 = ck::Number<0>{};
...@@ -158,9 +157,8 @@ struct CKDeviceGemm ...@@ -158,9 +157,8 @@ struct CKDeviceGemm
static constexpr auto I2 = ck::Number<2>{}; static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{}; static constexpr auto I3 = ck::Number<3>{};
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
{ {
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0); const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1); const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
...@@ -253,9 +251,8 @@ struct CKDeviceGemm ...@@ -253,9 +251,8 @@ struct CKDeviceGemm
} }
} }
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
{ {
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0); const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1); const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
...@@ -348,9 +345,8 @@ struct CKDeviceGemm ...@@ -348,9 +345,8 @@ struct CKDeviceGemm
} }
} }
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
{ {
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0); const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1); const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
...@@ -408,7 +404,7 @@ struct CKDeviceGemm ...@@ -408,7 +404,7 @@ struct CKDeviceGemm
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N()); // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
template<class CGridDesc_M_N> template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -416,7 +412,7 @@ struct CKDeviceGemm ...@@ -416,7 +412,7 @@ struct CKDeviceGemm
c_grid_desc_m_n); c_grid_desc_m_n);
} }
template<class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N> template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment