Commit 981b8549 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent ce82ed9b
......@@ -38,7 +38,12 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp;
template <index_t MPerWave, index_t NPerWave, index_t KPerWave, class FloatA, class FloatB, class FloatC>
template <index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_dpp_f32_8x8x8_f16<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
......@@ -47,19 +52,19 @@ struct dpp_gemm_type<DppGemmInstr::dpp_f32_8x8x8_f16,
template <index_t WaveSize>
struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
WaveSize,
typename std::enable_if_t<WaveSize == 64>>
WaveSize,
typename std::enable_if_t<WaveSize == 64>>
{
// * DPP GEMM setup
static constexpr index_t waves_per_wg = 4;
static constexpr index_t m_per_dpp = 8;
static constexpr index_t n_per_dpp = 8;
static constexpr index_t k_per_dpp = 8;
static constexpr index_t dpp_per_wave = 8;
static constexpr index_t src_a_data_size = 1;
static constexpr index_t src_b_data_size = 1;
static constexpr index_t acc_data_size = 4;
static constexpr index_t wave_size = WaveSize;
static constexpr index_t waves_per_wg = 4;
static constexpr index_t m_per_dpp = 8;
static constexpr index_t n_per_dpp = 8;
static constexpr index_t k_per_dpp = 8;
static constexpr index_t dpp_per_wave = 8;
static constexpr index_t src_a_data_size = 1;
static constexpr index_t src_b_data_size = 1;
static constexpr index_t acc_data_size = 4;
static constexpr index_t wave_size = WaveSize;
// * Thread mapping inside wave, num_thread_per_dpp always alone N direction
static constexpr index_t num_thread_per_dpp = n_per_dpp;
......@@ -71,8 +76,7 @@ struct dpp_gemm_type<DppGemmInstr::dpp_i32_8x8x8_i8,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(
a, b, reg_c);
intrin_dpp_i32_8x8x8_i8<MPerWave, NPerWave, KPerWave>::Run(a, b, reg_c);
}
};
......@@ -106,15 +110,19 @@ struct DppGemmSelector
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_dpp_gemm =
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(), Number<64>{}>{};
dpp_gemm_type<GetDppGemm<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>(),
Number<64>{}>{};
__host__ __device__ constexpr DppGemmSelector()
{
static_assert(selected_dpp_gemm.m_per_wave == 8, "Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.m_per_wave == 8,
"Something went wrong, M per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8, "Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.n_per_wave == 8,
"Something went wrong, N per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8, "Something went wrong, K per wave should be equal to 8");
static_assert(selected_dpp_gemm.k_per_wave == 8,
"Something went wrong, K per wave should be equal to 8");
}
};
......@@ -162,13 +170,12 @@ struct DppGemm
return transform_tensor_descriptor(
c_desc_mblockxrepeat_nblockxRepeat_mwave_nwave_mperdpp_nperdpp,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(MWave),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<dpp_gemm_instr.dpp_per_wave>{}),
make_pass_through_transform(Number<dpp_gemm_instr.num_thread_per_dpp>{})),
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(MWave),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<dpp_gemm_instr.dpp_per_wave>{}),
make_pass_through_transform(Number<dpp_gemm_instr.num_thread_per_dpp>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -188,25 +195,28 @@ struct DppGemm
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value),
"base type couple must be (half, float) "
"or (int8, int32)!");
static_assert((is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, int8_t>::value &&
is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value),
"base type couple must be (half, float) "
"or (int8, int32)!");
if constexpr(!TransposeC)
{
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_a_wave, p_b_wave, p_c_thread);
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_a_wave, p_b_wave, p_c_thread);
}
else
else
{
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(p_b_wave, p_a_wave, p_c_thread);
dpp_gemm_instr.template run<MPerWave, NPerWave, KPerWave>(
p_b_wave, p_a_wave, p_c_thread);
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_gemm_instr.wave_size; }
__device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % dpp_gemm_instr.wave_size;
}
__device__ static auto GetSubGroupId()
{
......@@ -222,5 +232,4 @@ struct DppGemm
DppGemmSelector<src_type_a, src_type_b, dst_type, MPerWave, NPerWave, KPerWave>{};
static constexpr auto dpp_gemm_instr = dpp_gemm.selected_dpp_gemm;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment