Commit 2d18473f authored by Paul's avatar Paul
Browse files

Format

parent 5dd55ba5
......@@ -36,10 +36,12 @@ namespace migraphx {
template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c)
{
constexpr const auto a_grid_desc_ak0_m_ak1 = G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto b_grid_desc_bk0_n_bk1 = G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr const auto a_grid_desc_ak0_m_ak1 =
G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto b_grid_desc_bk0_n_bk1 =
G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
......@@ -52,7 +54,9 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(),
c.data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment