Commit 2d18473f authored by Paul's avatar Paul
Browse files

Format

parent 5dd55ba5
...@@ -36,10 +36,12 @@ namespace migraphx { ...@@ -36,10 +36,12 @@ namespace migraphx {
template <class G, class A, class B, class C> template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c) __device__ void ck_gemm(const A& a, const B& b, const C& c)
{ {
constexpr const auto a_grid_desc_ak0_m_ak1 = G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>()); constexpr const auto a_grid_desc_ak0_m_ak1 =
constexpr const auto b_grid_desc_bk0_n_bk1 = G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>()); G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>()); constexpr const auto b_grid_desc_bk0_n_bk1 =
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n); G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
...@@ -52,7 +54,9 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -52,7 +54,9 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{})); constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(), b.data(),
c.data(), c.data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment