Commit 5dbaf3c2 authored by Jing Zhang
Browse files

refactor xdlops, hide c desc

parent 370c9245
...@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{ {
///\to-do: hide xdl clayout into xdlops-gemm constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
constexpr auto M0 = Number<CXdlopsLayout.M1()>{}; constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M2 = Number<CXdlopsLayout.M0()>{}; constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1)); return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
} }
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
...@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, K1> b_thread_vec; vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k0) { static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
// read A // read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k0, I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
...@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
using mfma_input_type = using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_per_blk>::type;
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment