Commit 981b8549 authored by Rosty Geyyer
Browse files

Format

parent ce82ed9b
...@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16, ...@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
// * Thread mapping inside wave, num_thread_per_subgroups always along N direction // * Thread mapping inside wave, num_thread_per_subgroups always along N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp; static constexpr index_t num_thread_per_dpp = n_per_dpp;
template <index_t MPerWave, index_t NPerWave, index_t KPerWave, class FloatA, class FloatB, class FloatC> template <index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c); intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
...@@ -47,19 +52,19 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16, ...@@ -47,19 +52,19 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
template <index_t WaveSize> template <index_t WaveSize>
struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8, struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
WaveSize, WaveSize,
typename std::enable_if_t<WaveSize == 64>> typename std::enable_if_t<WaveSize == 64>>
{ {
// * DPP GEMM setup // * DPP GEMM setup
static constexpr index_t waves_per_wg = 4; static constexpr index_t waves_per_wg = 4;
static constexpr index_t m_per_dpp = 8; static constexpr index_t m_per_dpp = 8;
static constexpr index_t n_per_dpp = 8; static constexpr index_t n_per_dpp = 8;
static constexpr index_t k_per_dpp = 8; static constexpr index_t k_per_dpp = 8;
static constexpr index_t dpp_per_wave = 8; static constexpr index_t dpp_per_wave = 8;
static constexpr index_t src_a_data_size = 1; static constexpr index_t src_a_data_size = 1;
static constexpr index_t src_b_data_size = 1; static constexpr index_t src_b_data_size = 1;
static constexpr index_t acc_data_size = 4; static constexpr index_t acc_data_size = 4;
static constexpr index_t wave_size = WaveSize; static constexpr index_t wave_size = WaveSize;
// * Thread mapping inside wave, num_thread_per_dpp always along N direction // * Thread mapping inside wave, num_thread_per_dpp always along N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp; static constexpr index_t num_thread_per_dpp = n_per_dpp;
...@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8, ...@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
class FloatC> class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run( intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
a, b, reg_c);
} }
}; };
...@@ -106,15 +110,19 @@ struct DppGemmSelector ...@@ -106,15 +110,19 @@ struct DppGemmSelector
// get_warp_size does not return the correct wave size; hardcode to 64 as a workaround // get_warp_size does not return the correct wave size; hardcode to 64 as a workaround
static constexpr auto selected_dpp_gemm = static constexpr auto selected_dpp_gemm =
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(), Number<64>{}>{}; dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(),
Number<64>{}>{};
__host__ __device__ constexpr DppGemmSelector() __host__ __device__ constexpr DppGemmSelector()
{ {
static_assert(selected_dpp_gemm.m_per_wave == 8, "Something went wrong, M per wave should be equal to 8"); static_assert(selected_dpp_gemm.m_per_wave == 8,
"Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8, "Something went wrong, N per wave should be equal to 8"); static_assert(selected_dpp_gemm.n_per_wave == 8,
"Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8, "Something went wrong, K per wave should be equal to 8"); static_assert(selected_dpp_gemm.k_per_wave == 8,
"Something went wrong, K per wave should be equal to 8");
} }
}; };
...@@ -162,13 +170,12 @@ struct DppGemm ...@@ -162,13 +170,12 @@ struct DppGemm
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp, c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp,
make_tuple( make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MBlockxRepeat), make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NBlockxRepeat), make_pass_through_transform(MWave),
make_pass_through_transform(MWave), make_pass_through_transform(NWave),
make_pass_through_transform(NWave), make_pass_through_transform(Number<dpp_gemm_instr.dpp_per_wave>{}),
make_pass_through_transform(Number<dpp_gemm_instr.dpp_per_wave>{}), make_pass_through_transform(Number<dpp_gemm_instr.num_thread_per_dpp>{})),
make_pass_through_transform(Number<dpp_gemm_instr.num_thread_per_dpp>{})),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -188,25 +195,28 @@ struct DppGemm ...@@ -188,25 +195,28 @@ struct DppGemm
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert( static_assert((is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, float>::value) ||
is_same<dst_type, float>::value) || (is_same<src_type_a, int8_t>::value &&
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value && is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value),
is_same<dst_type, int32_t>::value), "base type couple must be (half, float) "
"base type couple must be (half, float) " "or (int8, int32)!");
"or (int8, int32)!");
if constexpr(!TransposeC) if constexpr(!TransposeC)
{ {
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_a_wave, p_b_wave, p_c_thread); dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_a_wave, p_b_wave, p_c_thread);
} }
else else
{ {
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_b_wave, p_a_wave, p_c_thread); dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_b_wave, p_a_wave, p_c_thread);
} }
} }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_gemm_instr.wave_size; } __device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % dpp_gemm_instr.wave_size;
}
__device__ static auto GetSubGroupId() __device__ static auto GetSubGroupId()
{ {
...@@ -222,5 +232,4 @@ struct DppGemm ...@@ -222,5 +232,4 @@ struct DppGemm
DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{}; DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{};
static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm; static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm;
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment