Commit d562e265 authored by Paul's avatar Paul
Browse files

Format

parent c51b3d29
......@@ -113,11 +113,12 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if (tuning.empty())
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end()) {
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
return 4;
}
......
......@@ -13,23 +13,23 @@
namespace migraphx {
namespace detail {
template<class T>
template <class T>
struct to_ck_type_impl
{
using type = T;
};
template<>
template <>
struct to_ck_type_impl<migraphx::half>
{
using type = ck::half_t;
};
template<class Shape>
template <class Shape>
constexpr bool is_row_major()
{
constexpr auto strides = Shape{}.strides;
MIGRAPHX_ASSERT(strides.size() >= 2);
if (strides.back() == 1)
if(strides.back() == 1)
{
MIGRAPHX_ASSERT(not Shape{}.is_trasnposed());
return true;
......@@ -41,18 +41,21 @@ constexpr bool is_row_major()
} // namespace detail
template<class T>
template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(), ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor>;
template <class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(),
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor>;
template<class Tensor>
template <class Tensor>
constexpr auto to_ck_tensor()
{
constexpr auto s = get_shape_c<Tensor>{};
return sequence(s.lens.size(), [](auto... is) {
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...), ck::make_tuple(s.strides[is]...));
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...),
ck::make_tuple(s.strides[is]...));
});
}
......
......@@ -33,7 +33,6 @@
namespace migraphx {
template <class G, class A, class B, class C>
__device__ void ck_gemm(const A& a, const B& b, const C& c)
{
......@@ -48,11 +47,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
constexpr auto shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte();
constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ char p_shared_block[shared_block_size];
constexpr bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
constexpr bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
b.data(),
c.data(),
......@@ -64,7 +63,6 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
b_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
} // namespace migraphx
......
......@@ -152,9 +152,8 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CKDeviceGemm
{
template<class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using Make =
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using Make = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
......@@ -203,7 +202,7 @@ struct CKDeviceGemm
static constexpr auto BOp() { return BElementwiseOperation{}; }
static constexpr auto COp() { return CElementwiseOperation{}; }
// return block_id to C matrix tile idx (m0, n0) mapping
template<class CGridDesc_M_N>
template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment