Commit 0ff1d1f8 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Remove hardcoded datatypes

parent 7b7dd69d
...@@ -11,9 +11,9 @@ namespace ck { ...@@ -11,9 +11,9 @@ namespace ck {
enum struct DppInstr enum struct DppInstr
{ {
dpp8_16x16x2 = 0, dpp8_f16_16x16x2 = 0,
dpp8_8x32x2, dpp8_f16_8x32x2,
dpp8_32x8x2 dpp8_f16_32x8x2
}; };
/** /**
...@@ -42,7 +42,7 @@ template <DppInstr instr> ...@@ -42,7 +42,7 @@ template <DppInstr instr>
struct dpp_type; struct dpp_type;
template <> template <>
struct dpp_type<DppInstr::dpp8_32x8x2> struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{ {
static constexpr index_t wave_size = 32; static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8; static constexpr index_t lanegroup_size = 8;
...@@ -54,17 +54,25 @@ struct dpp_type<DppInstr::dpp8_32x8x2> ...@@ -54,17 +54,25 @@ struct dpp_type<DppInstr::dpp8_32x8x2>
static constexpr index_t n_per_thread = 1; static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2; static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true; static constexpr bool share_a = true;
using base_type = half_t;
template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC> template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>( dpp8::DppInstrRunner<m_per_thread,
a, b, reg_c); n_per_thread,
k_per_dpp,
base_type,
FloatA,
FloatB,
FloatC,
share_a>{}
.Run(a, b, reg_c);
} }
}; };
template <> template <>
struct dpp_type<DppInstr::dpp8_8x32x2> struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{ {
static constexpr index_t wave_size = 32; static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8; static constexpr index_t lanegroup_size = 8;
...@@ -76,17 +84,25 @@ struct dpp_type<DppInstr::dpp8_8x32x2> ...@@ -76,17 +84,25 @@ struct dpp_type<DppInstr::dpp8_8x32x2>
static constexpr index_t n_per_thread = 1; static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2; static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true; static constexpr bool share_a = true;
using base_type = half_t;
template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC> template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>( dpp8::DppInstrRunner<m_per_thread,
a, b, reg_c); n_per_thread,
k_per_dpp,
base_type,
FloatA,
FloatB,
FloatC,
share_a>{}
.Run(a, b, reg_c);
} }
}; };
template <> template <>
struct dpp_type<DppInstr::dpp8_16x16x2> struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{ {
static constexpr index_t wave_size = 32; static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8; static constexpr index_t lanegroup_size = 8;
...@@ -98,12 +114,20 @@ struct dpp_type<DppInstr::dpp8_16x16x2> ...@@ -98,12 +114,20 @@ struct dpp_type<DppInstr::dpp8_16x16x2>
static constexpr index_t n_per_thread = 1; static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2; static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true; static constexpr bool share_a = true;
using base_type = half_t;
template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC> template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>( dpp8::DppInstrRunner<m_per_thread,
a, b, reg_c); n_per_thread,
k_per_dpp,
base_type,
FloatA,
FloatB,
FloatC,
share_a>{}
.Run(a, b, reg_c);
} }
}; };
...@@ -116,19 +140,19 @@ struct DppSelector ...@@ -116,19 +140,19 @@ struct DppSelector
template <> template <>
static constexpr auto GetDpp<half_t, 8, 32>() static constexpr auto GetDpp<half_t, 8, 32>()
{ {
return DppInstr::dpp8_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() static constexpr auto GetDpp<half_t, 16, 16>()
{ {
return DppInstr::dpp8_16x16x2; return DppInstr::dpp8_f16_16x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 32, 8>() static constexpr auto GetDpp<half_t, 32, 8>()
{ {
return DppInstr::dpp8_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
static constexpr auto selected_dpp = dpp_type<GetDpp<base_type, MPerDpp, NPerDpp>()>{}; static constexpr auto selected_dpp = dpp_type<GetDpp<base_type, MPerDpp, NPerDpp>()>{};
......
...@@ -11,33 +11,57 @@ namespace ck { ...@@ -11,33 +11,57 @@ namespace ck {
namespace dpp8 { namespace dpp8 {
template <index_t MPerLanegroup, template <class ABDataType>
index_t NPerLanegroup, struct dpp_datatypes;
index_t KPerLanegroup,
class FloatA, template <>
class FloatB, struct dpp_datatypes<half_t>
class FloatVecC, {
// Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
// single instruction.
using a_dtype = half_t;
using b_dtype = half_t;
using c_dtype = float;
static constexpr index_t k_per_instr = 2;
};
template <index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
class BaseInputType,
class AVecDataType,
class BVecDataType,
class CVecDataType,
bool ShareA> bool ShareA>
__device__ void RunGemm(const FloatA& a, const FloatB& b, FloatVecC& c_vec) struct DppInstrRunner
{ {
constexpr index_t c_dim = ShareA ? MPerLanegroup : NPerLanegroup; static constexpr auto datatypes_conf = dpp_datatypes<BaseInputType>{};
using ADataType = typename decltype(datatypes_conf)::a_dtype;
const vector_type<half_t, KPerLanegroup> a_vector{a}; using BDataType = typename decltype(datatypes_conf)::b_dtype;
const vector_type<half_t, KPerLanegroup> b_vector{b}; using CDataType = typename decltype(datatypes_conf)::c_dtype;
static_for<0, c_dim, 1>{}([&](auto c_idx) { __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
float c = c_vec.template AsType<float>()(c_idx); {
// Next `c_idx` implies that we need to pull data from the next lane. constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
constexpr index_t source_lane = c_idx;
static_for<0, KPerLanegroup / 2, 1>{}([&](auto k_chunk) { const vector_type<ADataType, KPerThread> a_vector{a_vec};
const auto a_half2 = a_vector.template AsType<half2_t>()[k_chunk]; const vector_type<BDataType, KPerThread> b_vector{b_vec};
const auto b_half2 = b_vector.template AsType<half2_t>()[k_chunk];
ck::dpp8::inner_product_dpp<half2_t, half2_t, float, source_lane, ShareA>( static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
a_half2, b_half2, c); float c = c_vec.template AsType<CDataType>()(c_idx);
// Next `c_idx` implies that we need to pull data from the next lane.
constexpr index_t source_lane = c_idx;
static_for<0, KPerThread / datatypes_conf.k_per_instr, 1>{}([&](auto k_chunk) {
const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
ck::dpp8::
inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
a_k_vec, b_k_vec, c);
});
c_vec.template AsType<CDataType>()(c_idx) = c;
}); });
c_vec.template AsType<float>()(c_idx) = c; }
}); };
}
} // namespace dpp8 } // namespace dpp8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment