"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "dff93289bcc7673ed46d5192eb73196e3fe31efc"
Commit 5bf77d8b authored by aska-0096

clang-format

parent 2f88070a
@@ -211,27 +211,20 @@ struct BlockwiseGemmWMMA
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

         constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
         constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
         return make_naive_tensor_descriptor(
             //        |MRepeat |MWave |MSubGroup |NRepeat |NWave
             //        |NThreadPerSubGroup |MAccVgprs
-            make_tuple(Number<MRepeat>{},
-                       I1,
-                       I1,
-                       Number<NRepeat>{},
-                       I1,
-                       I1,
-                       MAccVgprs),
+            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
             make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                        Number<NRepeat>{} * MAccVgprs * AccStride,
                        Number<NRepeat>{} * MAccVgprs * AccStride,
                        MAccVgprs * AccStride,
                        MAccVgprs * AccStride,
                        MAccVgprs * AccStride,
-                       AccStride)
-            );
+                       AccStride));
 #if 0
         return make_naive_tensor_descriptor_packed(
             //        |MRepeat |MWave |MSubGroup |NRepeat |NWave
             //        |NThreadPerSubGroup |MAccVgprs
@@ -242,7 +235,7 @@ struct BlockwiseGemmWMMA
             I1,
             NThreadPerSubGroup,
             MAccVgprs));
 #endif
     }

     template <typename CGridDesc_M_N>

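For readers skimming the collapsed make_tuple call above, the descriptor it builds is a plain lengths/strides mapping: lengths (MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs) against the stride tuple in the second make_tuple. Below is a minimal standalone sketch of the offset arithmetic that descriptor encodes. It is not CK's descriptor machinery; the helper offset_7d, its parameters, and the example numbers are illustrative only.

// Sketch only: the offset arithmetic implied by the descriptor above, with the
// stride tuple copied from the hunk. Not CK code; offset_7d is a hypothetical helper.
// Index order: (MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs).
#include <array>
#include <cstddef>

constexpr std::size_t offset_7d(const std::array<std::size_t, 7>& idx,
                                std::size_t NRepeat,
                                std::size_t MAccVgprs,
                                std::size_t AccStride)
{
    const std::array<std::size_t, 7> stride = {NRepeat * MAccVgprs * AccStride, // MRepeat
                                               NRepeat * MAccVgprs * AccStride, // MWave (length 1)
                                               NRepeat * MAccVgprs * AccStride, // MSubGroup (length 1)
                                               MAccVgprs * AccStride,           // NRepeat
                                               MAccVgprs * AccStride,           // NWave (length 1)
                                               MAccVgprs * AccStride,           // NThreadPerSubGroup (length 1)
                                               AccStride};                      // MAccVgprs
    std::size_t off = 0;
    for(std::size_t d = 0; d < 7; ++d)
        off += idx[d] * stride[d];
    return off;
}

// Example: with NRepeat = 2, MAccVgprs = 8, AccStride = 1, element (1, 0, 0, 1, 0, 0, 3)
// lands at 1*16 + 1*8 + 3*1 = 27.
static_assert(offset_7d({1, 0, 0, 1, 0, 0, 3}, 2, 8, 1) == 27, "stride arithmetic check");
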
@@ -151,9 +151,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr auto B0EnableLds_manu = true;
     static constexpr auto B1EnableLds_manu = true;

-    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch >1);
-    static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch >1);
-    static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch >1);
+    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
+    static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
+    static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);

     using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
         Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,

@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto AEnableLds_manu = false;
     static constexpr auto BEnableLds_manu = false;

-    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch>1);
-    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch>1);
+    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
+    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -467,7 +467,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
            ck::get_device_name() == "gfx1102")
         {
-            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || is_same_v<AccDataType, int32_t>))
+            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
+                           is_same_v<AccDataType, int32_t>))
             {
                 printf("DeviceOp err: AccDataType");
                 return false;

@@ -177,8 +177,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
     static constexpr auto AEnableLds_manu = false;
     static constexpr auto BEnableLds_manu = false;

-    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
-    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
+    static constexpr auto AEnableLds =
+        AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
+    static constexpr auto BEnableLds =
+        BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);

     static constexpr auto conv_to_gemm_transformer =
         TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};

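The AEnableLds/B0EnableLds/B1EnableLds lines re-wrapped in the three device-op hunks above all encode the same rule: LDS staging is used when either the automatic heuristic or the manual override asks for it, and it is always forced on when more than one prefetch stage is configured. A tiny sketch of that rule follows; enable_lds is a hypothetical helper written only to restate the boolean expression from the diff, not a CK API.

// Sketch of the EnableLds rule shared by the device ops above.
constexpr bool enable_lds(bool auto_heuristic, bool manual_override, int num_prefetch_stages)
{
    return auto_heuristic || manual_override || (num_prefetch_stages > 1);
}

static_assert(enable_lds(false, false, 2), "multi-stage prefetch always stages through LDS");
static_assert(!enable_lds(false, false, 1), "single-stage prefetch can bypass LDS");
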
@@ -104,11 +104,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

-    template <index_t MPerWmma,
-              index_t NPerWmma,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
@@ -142,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size *acc_pack_number/ wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -182,11 +178,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

-    template <index_t MPerWmma,
-              index_t NPerWmma,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
@@ -261,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size *acc_pack_number / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
@@ -496,13 +488,11 @@ struct WmmaGemm
                       "(int8, int32) or (int4, int32)!");

         if constexpr(!TransposeC)
         {
-            wmma_instr.template run<MPerWmma, NPerWmma>(
-                p_a_wave, p_b_wave, p_c_thread);
+            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
         }
         else
         {
-            wmma_instr.template run<MPerWmma, NPerWmma>(
-                p_b_wave, p_a_wave, p_c_thread);
+            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
         }
     }
@@ -555,7 +545,10 @@ struct WmmaGemm
     __host__ __device__ static constexpr auto
     GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
     {
-        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{}, Number<wmma_instr.acc_pack_number>{});
+        return make_tuple(I1,
+                          I1,
+                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
+                          Number<wmma_instr.acc_pack_number>{});
     }
 };

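The two whitespace fixes to num_acc_vgprs_per_wave above do not change the value. As a sanity check, the same expression evaluated with parameters assumed here for a 16x16 instruction with fp32 accumulators on a wave32 target (values not taken from this diff) comes out to 8 accumulator registers per lane:

// Assumed parameters for a wmma_*_16x16x16 instruction with fp32 accumulators on a
// wave32 GPU; only the arithmetic mirrors the expression in the hunks above.
constexpr int m_per_wmma      = 16;
constexpr int n_per_wmma      = 16;
constexpr int acc_data_size   = 4; // bytes per fp32 accumulator element
constexpr int acc_pack_number = 1;
constexpr int wave_size       = 32;

constexpr int num_acc_vgprs_per_wave =
    m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;

static_assert(num_acc_vgprs_per_wave == 8, "256 fp32 accumulators over 32 lanes, 4 bytes per VGPR");
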