Commit d562e265 authored by Paul
Browse files

Format

parent c51b3d29
...@@ -113,11 +113,12 @@ static std::vector<tuning_entry> read_tuning(const std::string& s) ...@@ -113,11 +113,12 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
static std::size_t get_tuning_for(const std::vector<shape>& inputs) static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{ {
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, "")); static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if (tuning.empty()) if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl; std::cout << "*********** Warning: No CK tuning!" << std::endl;
auto it = std::find_if( auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; }); tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end()) { if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl; std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
return 4; return 4;
} }
...@@ -163,7 +164,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -163,7 +164,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
hip_compile_options options; hip_compile_options options;
auto block_size = get_block_size(instance); auto block_size = get_block_size(instance);
auto grid_size = get_grid_size(instance, m, n); auto grid_size = get_grid_size(instance, m, n);
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
options.output = c_shape; options.output = c_shape;
......
...@@ -13,23 +13,23 @@ ...@@ -13,23 +13,23 @@
namespace migraphx { namespace migraphx {
namespace detail { namespace detail {
template<class T> template <class T>
struct to_ck_type_impl struct to_ck_type_impl
{ {
using type = T; using type = T;
}; };
template<> template <>
struct to_ck_type_impl<migraphx::half> struct to_ck_type_impl<migraphx::half>
{ {
using type = ck::half_t; using type = ck::half_t;
}; };
template<class Shape> template <class Shape>
constexpr bool is_row_major() constexpr bool is_row_major()
{ {
constexpr auto strides = Shape{}.strides; constexpr auto strides = Shape{}.strides;
MIGRAPHX_ASSERT(strides.size() >= 2); MIGRAPHX_ASSERT(strides.size() >= 2);
if (strides.back() == 1) if(strides.back() == 1)
{ {
MIGRAPHX_ASSERT(not Shape{}.is_trasnposed()); MIGRAPHX_ASSERT(not Shape{}.is_trasnposed());
return true; return true;
...@@ -41,18 +41,21 @@ constexpr bool is_row_major() ...@@ -41,18 +41,21 @@ constexpr bool is_row_major()
} // namespace detail } // namespace detail
template<class T> template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type; using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class Shape> template <class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(), ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor>; using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(),
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor>;
template<class Tensor> template <class Tensor>
constexpr auto to_ck_tensor() constexpr auto to_ck_tensor()
{ {
constexpr auto s = get_shape_c<Tensor>{}; constexpr auto s = get_shape_c<Tensor>{};
return sequence(s.lens.size(), [](auto... is) { return sequence(s.lens.size(), [](auto... is) {
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...), ck::make_tuple(s.strides[is]...)); return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...),
ck::make_tuple(s.strides[is]...));
}); });
} }
......
...@@ -33,14 +33,13 @@ ...@@ -33,14 +33,13 @@
namespace migraphx { namespace migraphx {
template <class G, class A, class B, class C> template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c) __device__ void ck_gemm(const A& a, const B& b, const C& c)
{ {
constexpr auto a_desc = to_ck_tensor<A>(); constexpr auto a_desc = to_ck_tensor<A>();
constexpr auto b_desc = to_ck_tensor<B>(); constexpr auto b_desc = to_ck_tensor<B>();
constexpr auto c_desc = to_ck_tensor<C>(); constexpr auto c_desc = to_ck_tensor<C>();
constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc); constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc);
using GridwiseGemm = typename G::template Make<a_desc, b_desc, c_desc>; using GridwiseGemm = typename G::template Make<a_desc, b_desc, c_desc>;
// static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map)); // static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map));
...@@ -48,11 +47,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -48,11 +47,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
constexpr auto shared_block_size = constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ char p_shared_block[shared_block_size]; __shared__ char p_shared_block[shared_block_size];
constexpr bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements()); constexpr bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(), b.data(),
c.data(), c.data(),
...@@ -64,7 +63,6 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -64,7 +63,6 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
b_desc, b_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map);
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -152,9 +152,8 @@ template <typename ALayout, ...@@ -152,9 +152,8 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()> ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CKDeviceGemm struct CKDeviceGemm
{ {
template<class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N> template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using Make = using Make = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -203,7 +202,7 @@ struct CKDeviceGemm ...@@ -203,7 +202,7 @@ struct CKDeviceGemm
static constexpr auto BOp() { return BElementwiseOperation{}; } static constexpr auto BOp() { return BElementwiseOperation{}; }
static constexpr auto COp() { return CElementwiseOperation{}; } static constexpr auto COp() { return CElementwiseOperation{}; }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
template<class CGridDesc_M_N> template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment