Commit 646659dd authored by Paul's avatar Paul
Browse files

Use const variables

parent 34520806
...@@ -36,12 +36,10 @@ namespace migraphx { ...@@ -36,12 +36,10 @@ namespace migraphx {
template <class G, class A, class B, class C> template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c) __device__ void ck_gemm(const A& a, const B& b, const C& c)
{ {
constexpr G gemm{}; constexpr const auto a_grid_desc_ak0_m_ak1 = G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto b_grid_desc_bk0_n_bk1 = G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>()); constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>()); constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
...@@ -49,21 +47,20 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -49,21 +47,20 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map)); // c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock = constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_block[shared_block_size];
constexpr bool HasMainKBlockLoop = constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements()); GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(), b.data(),
c.data(), c.data(),
p_shared_block, p_shared_block,
gemm.a_element_op, G{}.a_element_op,
gemm.b_element_op, G{}.b_element_op,
gemm.c_element_op, G{}.c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment