Commit 1e2a641e authored by Paul's avatar Paul
Browse files

Format

parent d1a1a28b
This diff is collapsed.
...@@ -59,14 +59,15 @@ constexpr auto to_ck_tensor() ...@@ -59,14 +59,15 @@ constexpr auto to_ck_tensor()
}); });
} }
template<class F> template <class F>
struct ck_function_adaptor : F struct ck_function_adaptor : F
{ {
template<class... Ts> template <class... Ts>
constexpr ck_function_adaptor(Ts&&... xs) : F(static_cast<Ts&&>(xs)...) constexpr ck_function_adaptor(Ts&&... xs) : F(static_cast<Ts&&>(xs)...)
{} {
}
template<class T, class... Ts> template <class T, class... Ts>
constexpr void operator()(T& out, Ts&&... xs) const constexpr void operator()(T& out, Ts&&... xs) const
{ {
out = static_cast<const F&>(*this)(static_cast<Ts&&>(xs)...); out = static_cast<const F&>(*this)(static_cast<Ts&&>(xs)...);
...@@ -75,9 +76,10 @@ struct ck_function_adaptor : F ...@@ -75,9 +76,10 @@ struct ck_function_adaptor : F
struct ck_nop struct ck_nop
{ {
template<class T> template <class T>
constexpr void operator()(T&) const constexpr void operator()(T&) const
{} {
}
}; };
} // namespace migraphx } // namespace migraphx
......
...@@ -40,28 +40,30 @@ __device__ void ck_gemm(A a, B b, E e, Ds... ds) ...@@ -40,28 +40,30 @@ __device__ void ck_gemm(A a, B b, E e, Ds... ds)
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>()); constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n = ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...); constexpr const auto ds_grid_desc_m_n =
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n); ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm; using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); constexpr auto a_grid_desc_ak0_m_ak1 =
constexpr auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
e_grid_desc_m_n);
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop = constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{})); GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(), b.data(),
ck::make_tuple(ds.data()...), ck::make_tuple(ds.data()...),
......
...@@ -156,7 +156,8 @@ template <typename ALayout, ...@@ -156,7 +156,8 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()> ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceGemmMultipleD struct CK_DeviceGemmMultipleD
{ {
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t> matrix_padder {MPerBlock, NPerBlock, KPerBlock}; ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
matrix_padder{MPerBlock, NPerBlock, KPerBlock};
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
...@@ -203,7 +204,7 @@ struct CK_DeviceGemmMultipleD ...@@ -203,7 +204,7 @@ struct CK_DeviceGemmMultipleD
LoopSched>; LoopSched>;
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
template<class EGridDesc_M_N> template <class EGridDesc_M_N>
__device__ static constexpr auto __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment