"...composable_kernel.git" did not exist on "e921e1f08dc04bc4bdf8a1efeb2c1623ff336a6d"
Commit 02bf2be0 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent dfbe7e20
...@@ -13,6 +13,9 @@ namespace ck { ...@@ -13,6 +13,9 @@ namespace ck {
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t GemmMPerBlock, template <index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPerWave,
typename... Wei, typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
...@@ -106,9 +109,17 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -106,9 +109,17 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto CLayout = xdlops_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc, out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)), make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
......
...@@ -26,7 +26,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -26,7 +26,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto XdlopsGemm = XdlopsGemm_t<float, MPerWave, NPerWave, KPerWave>{}; static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{};
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
...@@ -35,16 +35,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -35,16 +35,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t MWaves = MPerBlock / MPerWave; static constexpr index_t MWaves = MPerBlock / MPerWave;
static constexpr index_t NWaves = NPerBlock / NPerWave; static constexpr index_t NWaves = NPerBlock / NPerWave;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); } __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
__device__ constexpr auto GetNumBlks() const __device__ constexpr auto GetNumBlks() const
{ {
return XdlopsGemm.GetOutputLayout().GetNumBlks(); return xdlops_gemm.GetOutputLayout().GetNumBlks();
} }
__device__ constexpr auto GetBlkSize() const __device__ constexpr auto GetBlkSize() const
{ {
return XdlopsGemm.GetOutputLayout().GetBlkSize(); return xdlops_gemm.GetOutputLayout().GetBlkSize();
} }
__device__ static auto CalculateAThreadOriginDataIndex() __device__ static auto CalculateAThreadOriginDataIndex()
...@@ -75,7 +75,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -75,7 +75,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId = get_thread_local_1d_id() / WaveSize; const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(blk_i); const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row; const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col; const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
...@@ -127,7 +127,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -127,7 +127,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0), make_tuple(I0, I0),
b_thread_buf); b_thread_buf);
XdlopsGemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
}); });
} }
......
...@@ -333,7 +333,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -333,7 +333,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{} // Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
vector_type<float, 64> c_thread_buf; constexpr auto c_vec_size = MPerBlock * NPerBlock / BlockSize;
vector_type<float, c_vec_size> c_thread_buf;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -466,15 +468,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -466,15 +468,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{ {
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout(); constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1(); constexpr index_t M0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1(); constexpr index_t M1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0(); constexpr index_t M2 = OutputLayout.M0();
// static_assert(K0 == 4 && K1 == 2 && K2 == 4, ""); // static_assert(M0 == 4 && M1 == 2 && M2 == 4, "");
constexpr auto c_m0_m1_m2_n_thread_desc = constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<K0>{}, Number<1>{}, Number<K2>{}, Number<1>{})); make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
constexpr index_t BlkSize = OutputLayout.GetBlkSize(); constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks(); constexpr index_t NumBlks = OutputLayout.GetNumBlks();
...@@ -508,16 +510,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -508,16 +510,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc), decltype(c_m0_m1_m2_n_global_desc),
Sequence<K0, 1, K2, 1>, Sequence<M0, 1, M2, 1>,
Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
3, // CThreadTransferSrcDstVectorDim, 3, // CThreadTransferSrcDstVectorDim,
1, // CThreadTransferDstScalarPerVector, 1, // CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{c_m0_m1_m2_n_global_desc, true>{c_m0_m1_m2_n_global_desc,
make_multi_index(k_thread_data_on_global / (K2 * K1), make_multi_index(k_thread_data_on_global / (M2 * M1),
k_thread_data_on_global % (K2 * K1) / K2, k_thread_data_on_global % (M2 * M1) / M2,
k_thread_data_on_global % K2, k_thread_data_on_global % M2,
b_thread_data_on_global)} b_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc, .Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
......
...@@ -53,7 +53,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> ...@@ -53,7 +53,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC> template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::run(a, b, reg_c); return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
...@@ -74,19 +74,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> ...@@ -74,19 +74,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
static constexpr index_t cycles = 64; static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1; static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
index_t NPerXdlops, __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const float*>(a); return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_32x32x2f32(p_a, p_b, reg_c);
} }
}; };
...@@ -548,7 +539,7 @@ struct xdlops_info ...@@ -548,7 +539,7 @@ struct xdlops_info
}; };
template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave> template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave>
struct XdlopsGemm_t struct XdlopsGemm
{ {
struct MatrixIndex struct MatrixIndex
{ {
...@@ -561,7 +552,7 @@ struct XdlopsGemm_t ...@@ -561,7 +552,7 @@ struct XdlopsGemm_t
return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n); return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
} }
__device__ constexpr XdlopsGemm_t() __host__ __device__ constexpr XdlopsGemm()
{ {
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64, NPerXdlops == 64,
...@@ -849,10 +840,10 @@ struct XdlopsGemm_t ...@@ -849,10 +840,10 @@ struct XdlopsGemm_t
struct OutputLayout struct OutputLayout
{ {
__device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__device__ static constexpr index_t M0() { return mfma_type.group_size; } __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
...@@ -867,7 +858,7 @@ struct XdlopsGemm_t ...@@ -867,7 +858,7 @@ struct XdlopsGemm_t
} }
}; };
__device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; } __host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
}; };
} // namespace ck } // namespace ck
......
...@@ -241,7 +241,7 @@ template <> ...@@ -241,7 +241,7 @@ template <>
struct intrin_mfma_f32_32x32x1f32<64, 64> struct intrin_mfma_f32_32x32x1f32<64, 64>
{ {
__device__ static void __device__ static void
run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c) Run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
...@@ -272,12 +272,19 @@ struct intrin_mfma_f32_32x32x1f32<64, 64> ...@@ -272,12 +272,19 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
//} //}
//}; //};
__device__ c_vec16_1_t::VecType template <index_t MPerWave, index_t NPerWave>
intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c) struct intrin_mfma_f32_32x32x2f32;
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); __device__ static void
return reg_c; Run(const float& reg_a, const float& reg_b, vector_type<float, 16>& reg_c)
} {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
};
__device__ c_vec4_1_t::VecType __device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c) intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
......
...@@ -71,13 +71,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -71,13 +71,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2( const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(OutDesc::GetLengths())); sequence_to_tuple_of_number(OutDesc::GetLengths()));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
// b thread copy 4x1 #if 0
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64; constexpr index_t GemmMPerBlock = 64;
...@@ -101,13 +101,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -101,13 +101,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 32;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPerWave = 2;
constexpr index_t GemmM1 = GemmMPerWave; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
constexpr index_t GemmN1 = GemmNPerWave; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs = const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock, transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock,
GemmNPerBlock>( GemmNPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPerWave>(
wei_k_c_y_x_desc, wei_k_c_y_x_desc,
in_n_c_hi_wi_desc, in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc, out_n_k_ho_wo_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment