Commit 981b8549 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent ce82ed9b
...@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16, ...@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp; static constexpr index_t num_thread_per_dpp = n_per_dpp;
template <index_t MPerWave, index_t NPerWave, index_t KPerWave, class FloatA, class FloatB, class FloatC> template <index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c); intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
...@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8, ...@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
class FloatC> class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run( intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
a, b, reg_c);
} }
}; };
...@@ -106,15 +110,19 @@ struct DppGemmSelector ...@@ -106,15 +110,19 @@ struct DppGemmSelector
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_dpp_gemm = static constexpr auto selected_dpp_gemm =
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(), Number<64>{}>{}; dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(),
Number<64>{}>{};
__host__ __device__ constexpr DppGemmSelector() __host__ __device__ constexpr DppGemmSelector()
{ {
static_assert(selected_dpp_gemm.m_per_wave == 8, "Something went wrong, M per wave should be equal to 8"); static_assert(selected_dpp_gemm.m_per_wave == 8,
"Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8, "Something went wrong, N per wave should be equal to 8"); static_assert(selected_dpp_gemm.n_per_wave == 8,
"Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8, "Something went wrong, K per wave should be equal to 8"); static_assert(selected_dpp_gemm.k_per_wave == 8,
"Something went wrong, K per wave should be equal to 8");
} }
}; };
...@@ -162,8 +170,7 @@ struct DppGemm ...@@ -162,8 +170,7 @@ struct DppGemm
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp, c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp,
make_tuple( make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(NBlockxRepeat), make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(MWave), make_pass_through_transform(MWave),
make_pass_through_transform(NWave), make_pass_through_transform(NWave),
...@@ -188,25 +195,28 @@ struct DppGemm ...@@ -188,25 +195,28 @@ struct DppGemm
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert( static_assert((is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) || is_same<dst_type, float>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value && (is_same<src_type_a, int8_t>::value &&
is_same<dst_type, int32_t>::value), is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value),
"base type couple must be (half, float) " "base type couple must be (half, float) "
"or (int8, int32)!"); "or (int8, int32)!");
if constexpr(!TransposeC) if constexpr(!TransposeC)
{ {
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_a_wave, p_b_wave, p_c_thread); dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_a_wave, p_b_wave, p_c_thread);
} }
else else
{ {
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_b_wave, p_a_wave, p_c_thread); dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_b_wave, p_a_wave, p_c_thread);
} }
} }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_gemm_instr.wave_size; } __device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % dpp_gemm_instr.wave_size;
}
__device__ static auto GetSubGroupId() __device__ static auto GetSubGroupId()
{ {
...@@ -222,5 +232,4 @@ struct DppGemm ...@@ -222,5 +232,4 @@ struct DppGemm
DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{}; DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{};
static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm; static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm;
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment