Commit 33937683 authored by Lei Wang, committed by LeiWang1999

[Refactor] Update GEMM layout and operand traits for improved CUDA compatibility (#500)

* [Enhancement] Improve GEMM layout function and documentation

* Added detailed documentation for the makeGemmABLayout function, explaining parameters and layout selection strategies.
* Updated the layout selection logic to use mat_continuous consistently, enhancing clarity and correctness in memory layout calculations.
* Adjusted the InferLayout method to reflect changes in the layout function, ensuring accurate matrix dimension handling for transposed cases.
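
For reference, a minimal sketch of the call site this change touches, as it appears in the diff below; the comment on kfactor is an inference from `trans_A ? 1 : 2`, not wording from the source:

    // gemm.cc (sketch): pick the shared-memory layout for operand A.
    // mat_continuous now feeds both the continuous-extent and continuity
    // arguments; kfactor selects the transposed (1) vs non-transposed (2)
    // layout strategy for A.
    results.Set(A,
                makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                 A->dtype.bits(),
                                 /*kfactor=*/trans_A ? 1 : 2));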

* lint fix

* [Refactor] Update GEMM layout and operand traits for improved CUDA compatibility

* Adjusted the InferLayout method in gemm.cc to include trans_A in fragment creation, enhancing layout inference for transposed matrices.
* Updated OperandTraits in gemm_sm89.h and gemm_sm90.h to change the Copy type from SM75_U16x4_LDSM_N to SM75_U16x4_LDSM_T, optimizing memory access patterns for different warp configurations.
* Enhanced static assertions in gemm_sm90.h to clarify requirements for num_warp_m, ensuring compatibility with Hopper architecture.
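
As a rough illustration of the Copy-type switch (a sketch against CuTe's public copy atoms, not the actual OperandTraits source; the `half_t` element type is an assumption):

    #include <cute/tensor.hpp>
    using namespace cute;

    // LDSM_T is the transposing form of ldmatrix: each thread's 8x8
    // fragments arrive from shared memory already transposed, whereas
    // the LDSM_N variants return them in the natural row order.
    using OperandCopy = Copy_Atom<SM75_U16x4_LDSM_T, half_t>;

The num_warp_m requirement reflects that Hopper's wgmma instructions execute over a full warpgroup of 4 warps (128 threads): warps along M come in groups of four, so the per-warpgroup M extent is 4 * M / num_warp_m (e.g. M = 128 with num_warp_m = 8 yields two warpgroups covering 64 rows each), which is the tile shape change visible in the gemm_sm90.h hunks below.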

* [Refactor] Clean up formatting in GEMM implementation and CUDA templates

* Simplified the formatting of the fragment creation in the InferLayout method of gemm.cc for better readability.
* Adjusted the static assertion message in gemm_sm90.h to enhance clarity regarding the num_warp_m requirement for Hopper architecture.
parent 2837878f
gemm.cc
@@ -248,9 +248,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
             makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                              A->dtype.bits(), trans_A ? 1 : 2));
   } else {
-    ICHECK(trans_A == false);
-    auto fragment =
-        makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
+    auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
+                                      A->dtype.bits(), trans_A);
     results.Set(A, fragment->BindThreadRange(thread_range));
   }
   if (B.scope() == "shared" || B.scope() == "shared.dyn") {
gemm_sm90.h
@@ -34,11 +34,9 @@ public:
       trans_B ? GMMA::Major::K : GMMA::Major::MN;
   using SmemLayoutAtomA =
-      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M / (num_warp_m / 4)>,
-                                Int<K>>());
+      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
   using SmemLayoutAtomB =
-      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N / num_warp_n>,
-                                Int<K>>());
+      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
   using SmemLayoutA = decltype(tile_to_shape(
       SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
@@ -47,7 +45,8 @@ public:
       SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
       conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
-  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
+  static_assert(num_warp_m % 4 == 0,
+                "num_warp_m must be a multiple of 4 for hopper wgmma");
   template <int wg_wait = 0>
   static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
@@ -59,7 +58,7 @@ public:
     auto tiled_mma = make_tiled_mma(
         GMMA::ss_op_selector<
             A_type, B_type, C_type,
-            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
+            Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
             GmmaMajorA, GmmaMajorB>(),
         Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
     auto thr_mma = tiled_mma.get_thread_slice(tid);
@@ -93,14 +92,6 @@ public:
       warpgroup_wait<wg_wait>();
     }
     warpgroup_fence_operand(acc);
-    // warpgroup_fence_operand(acc);
-    // warpgroup_arrive();
-    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
-    // warpgroup_commit_batch();
-    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
-    // warpgroup_fence_operand(acc);
   }
   template <int wg_wait = 0>
...
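
The commented-out block removed in the last hunk mirrored the live path above it. For context, the canonical CuTe warpgroup sequence it sketched (assuming the tiled_mma, tCrA, tCrB, and acc definitions from the surrounding gemm_sm90.h body) is:

    warpgroup_fence_operand(acc);   // fence the accumulator registers
    warpgroup_arrive();             // mark operands ready for wgmma
    gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
    warpgroup_commit_batch();       // close the in-flight wgmma batch
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(acc);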