Commit 646659dd authored by Paul's avatar Paul
Browse files

Use const variables

parent 34520806
......@@ -36,12 +36,10 @@ namespace migraphx {
template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c)
{
constexpr G gemm{};
constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr const auto a_grid_desc_ak0_m_ak1 = G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto b_grid_desc_bk0_n_bk1 = G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
......@@ -49,21 +47,20 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ char p_shared_block[shared_block_size];
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr bool HasMainKBlockLoop =
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(),
c.data(),
p_shared_block,
gemm.a_element_op,
gemm.b_element_op,
gemm.c_element_op,
G{}.a_element_op,
G{}.b_element_op,
G{}.c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment