Commit 1e2a641e authored by Paul's avatar Paul
Browse files

Format

parent d1a1a28b
This diff is collapsed.
......@@ -59,14 +59,15 @@ constexpr auto to_ck_tensor()
});
}
template<class F>
template <class F>
struct ck_function_adaptor : F
{
template<class... Ts>
template <class... Ts>
constexpr ck_function_adaptor(Ts&&... xs) : F(static_cast<Ts&&>(xs)...)
{}
{
}
template<class T, class... Ts>
template <class T, class... Ts>
constexpr void operator()(T& out, Ts&&... xs) const
{
out = static_cast<const F&>(*this)(static_cast<Ts&&>(xs)...);
......@@ -75,9 +76,10 @@ struct ck_function_adaptor : F
struct ck_nop
{
template<class T>
template <class T>
constexpr void operator()(T&) const
{}
{
}
};
} // namespace migraphx
......
......@@ -40,28 +40,30 @@ __device__ void ck_gemm(A a, B b, E e, Ds... ds)
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n = ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n);
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n);
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(),
ck::make_tuple(ds.data()...),
......
......@@ -156,7 +156,8 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceGemmMultipleD
{
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t> matrix_padder {MPerBlock, NPerBlock, KPerBlock};
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
matrix_padder{MPerBlock, NPerBlock, KPerBlock};
// GridwiseGemm
using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
......@@ -203,7 +204,7 @@ struct CK_DeviceGemmMultipleD
LoopSched>;
// return block_id to E matrix tile idx (m0, n0) mapping
template<class EGridDesc_M_N>
template <class EGridDesc_M_N>
__device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment