Commit 33937683 authored by Lei Wang, committed by LeiWang1999
Browse files

[Refactor] Update GEMM layout and operand traits for improved CUDA compatibility (#500)

* [Enhancement] Improve GEMM layout function and documentation

* Added detailed documentation for the makeGemmABLayout function, explaining parameters and layout selection strategies.
* Updated the layout selection logic to use mat_continuous consistently, enhancing clarity and correctness in memory layout calculations.
* Adjusted the InferLayout method to reflect changes in the layout function, ensuring accurate matrix dimension handling for transposed cases.

* lint fix

* [Refactor] Update GEMM layout and operand traits for improved CUDA compatibility

* Adjusted the InferLayout method in gemm.cc to include trans_A in fragment creation, enhancing layout inference for transposed matrices.
* Updated OperandTraits in gemm_sm89.h and gemm_sm90.h to change the Copy type from SM75_U16x4_LDSM_N to SM75_U16x4_LDSM_T, optimizing memory access patterns for different warp configurations.
* Enhanced static assertions in gemm_sm90.h to clarify requirements for num_warp_m, ensuring compatibility with Hopper architecture.

* [Refactor] Clean up formatting in GEMM implementation and CUDA templates

* Simplified the formatting of the fragment creation in the InferLayout method of gemm.cc for better readability.
* Adjusted the static assertion message in gemm_sm90.h to enhance clarity regarding the num_warp_m requirement for Hopper architecture.
parent 2837878f
...@@ -248,9 +248,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -248,9 +248,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2)); A->dtype.bits(), trans_A ? 1 : 2));
} else { } else {
ICHECK(trans_A == false); auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
auto fragment = A->dtype.bits(), trans_A);
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
results.Set(A, fragment->BindThreadRange(thread_range)); results.Set(A, fragment->BindThreadRange(thread_range));
} }
if (B.scope() == "shared" || B.scope() == "shared.dyn") { if (B.scope() == "shared" || B.scope() == "shared.dyn") {
......
...@@ -34,11 +34,9 @@ public: ...@@ -34,11 +34,9 @@ public:
trans_B ? GMMA::Major::K : GMMA::Major::MN; trans_B ? GMMA::Major::K : GMMA::Major::MN;
using SmemLayoutAtomA = using SmemLayoutAtomA =
decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M / (num_warp_m / 4)>, decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
Int<K>>());
using SmemLayoutAtomB = using SmemLayoutAtomB =
decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N / num_warp_n>, decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
Int<K>>());
using SmemLayoutA = decltype(tile_to_shape( using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{}, SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
...@@ -47,7 +45,8 @@ public: ...@@ -47,7 +45,8 @@ public:
SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{}, SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{})); conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); static_assert(num_warp_m % 4 == 0,
"num_warp_m must be a multiple of 4 for hopper wgmma");
template <int wg_wait = 0> template <int wg_wait = 0>
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
...@@ -59,7 +58,7 @@ public: ...@@ -59,7 +58,7 @@ public:
auto tiled_mma = make_tiled_mma( auto tiled_mma = make_tiled_mma(
GMMA::ss_op_selector< GMMA::ss_op_selector<
A_type, B_type, C_type, A_type, B_type, C_type,
Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>, Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(), GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{}); Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid); auto thr_mma = tiled_mma.get_thread_slice(tid);
...@@ -93,14 +92,6 @@ public: ...@@ -93,14 +92,6 @@ public:
warpgroup_wait<wg_wait>(); warpgroup_wait<wg_wait>();
} }
warpgroup_fence_operand(acc); warpgroup_fence_operand(acc);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
} }
template <int wg_wait = 0> template <int wg_wait = 0>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment