"...composable_kernel_rocm.git" did not exist on "cec69bc3bc200de7e09396579fe33cb153f8afeb"
Commit d7a0a3f9 authored by Jing Zhang's avatar Jing Zhang
Browse files

renaming/comments

parent 2cbb8976
...@@ -247,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -247,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
} }
private: private:
// A[K, M] // A[K0, M0, M1, M2, K1]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
// B[K, N] // B[K0, N0, N1, N2, K1]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
// C[M, N]
static constexpr auto c_thread_desc_ = static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
......
...@@ -545,7 +545,7 @@ struct MfmaSelector ...@@ -545,7 +545,7 @@ struct MfmaSelector
selected_mfma.k_per_blk; selected_mfma.k_per_blk;
} }
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
}; };
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
...@@ -708,7 +708,7 @@ struct XdlopsGemm ...@@ -708,7 +708,7 @@ struct XdlopsGemm
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetKPerThread(); static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment