Commit 981b8549 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent ce82ed9b
......@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp;
template <index_t MPerWave, index_t NPerWave, index_t KPerWave, class FloatA, class FloatB, class FloatC>
template <index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
......@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(
a, b, reg_c);
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
}
};
......@@ -106,15 +110,19 @@ struct DppGemmSelector
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_dpp_gemm =
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(), Number<64>{}>{};
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(),
Number<64>{}>{};
__host__ __device__ constexpr DppGemmSelector()
{
static_assert(selected_dpp_gemm.m_per_wave == 8, "Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.m_per_wave == 8,
"Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8, "Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8,
"Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8, "Something went wrong, K per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8,
"Something went wrong, K per wave should be equal to 8");
}
};
......@@ -162,8 +170,7 @@ struct DppGemm
return transform_tensor_descriptor(
c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(MWave),
make_pass_through_transform(NWave),
......@@ -188,25 +195,28 @@ struct DppGemm
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
static_assert((is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value),
(is_same<src_type_a, int8_t>::value &&
is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value),
"base type couple must be (half, float) "
"or (int8, int32)!");
if constexpr(!TransposeC)
{
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_a_wave, p_b_wave, p_c_thread);
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_a_wave, p_b_wave, p_c_thread);
}
else
{
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_b_wave, p_a_wave, p_c_thread);
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_b_wave, p_a_wave, p_c_thread);
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_gemm_instr.wave_size; }
__device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % dpp_gemm_instr.wave_size;
}
__device__ static auto GetSubGroupId()
{
......@@ -222,5 +232,4 @@ struct DppGemm
DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{};
static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment