Commit c5bbc608 authored by zqh-wz, committed by LeiWang1999

[Bugfix] Fix mismatch of shared memory layout and mma atom on Hopper (#224)



* add test for issue 101

* use ss_smem_selector from cutlass

* fix mismatch between smem layout and mma

* only fix for sm90

* Add CUDA requirements to GEMM thread tests

* lint fix

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 3de9f13c
@@ -132,7 +132,7 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                  const int element_size) {
   ICHECK(block_m % warp_m == 0);
   // ICHECK(block_n == warp_n);
-  ICHECK(warp_m % 16 == 0);
+  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
   auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
                                                    false); // 16 x N (1 warp)
   auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
@@ -478,24 +478,24 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
     return makeGemmABLayoutPadded(stride, continuous, 16);
 }
 
-Layout makeGemmABLayout(int stride, int continuous, int element_size,
-                        int kfactor) {
+Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
+                        int element_size, int kfactor) {
   if (element_size == 64) {
-    if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
-      return makeGemmABLayoutF64_Kouter(stride, continuous);
-    if (kfactor == 2 && continuous % 16 == 0) // float64 NxK
-      return makeGemmABLayoutF64_Kinner(stride, continuous);
-    return makeGemmABLayoutPadded(stride, continuous, element_size);
+    if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
+      return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
+    if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
+      return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
+    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
   }
   int vector_size = 128 / element_size;
   if (kfactor == 1 && element_size == 8) // int8 KxN
-    return makeGemmABLayoutPadded(stride, continuous, element_size);
-  else if (continuous % (vector_size * 8) == 0)
-    return makeFullBankSwizzleLayout(stride, continuous, element_size);
-  else if (continuous % (vector_size * 4) == 0)
-    return makeHalfBankSwizzleLayout(stride, continuous, element_size);
+    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
+  else if (continuity % (vector_size * 8) == 0)
+    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
+  else if (continuity % (vector_size * 4) == 0)
+    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
   else {
-    return makeGemmABLayoutPadded(stride, continuous, element_size);
+    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
   }
 }
......
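Note: the swizzle selection in makeGemmABLayout depends only on the contiguous extent it is given, so the effect of passing a per-warp-group continuity can be shown with a small standalone sketch. swizzle_bytes below is an invented helper (not part of the patch) that mirrors the vector_size checks in the hunk above; the concrete numbers (fp16, a 192-column tile split into two partitions along N) are hypothetical and only chosen to match the test added by this commit.

// Invented helper mirroring the swizzle-width selection above.
// element_size is in bits; returns the swizzle width in bytes (0 = padded).
constexpr int swizzle_bytes(int continuity, int element_size) {
  const int vector_size = 128 / element_size; // elements per 16-byte vector
  if (continuity % (vector_size * 8) == 0)
    return 128; // makeFullBankSwizzleLayout
  if (continuity % (vector_size * 4) == 0)
    return 64;  // makeHalfBankSwizzleLayout
  return 0;     // makeGemmABLayoutPadded
}

// fp16 (element_size = 16), whole 192-column block: 128-byte swizzle.
static_assert(swizzle_bytes(192, 16) == 128, "old continuity = block width");
// fp16, per-warp-group width 192 / 2 = 96: only the 64-byte swizzle fits,
// matching the per-warp-group tile the WGMMA-side selector now sees.
static_assert(swizzle_bytes(96, 16) == 64, "new continuity = per-group width");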
@@ -448,7 +448,7 @@ TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
 
 TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
     .set_body_typed([](int stride, int continuous, int element_size) {
-      return makeGemmABLayout(stride, continuous, element_size, 0);
+      return makeGemmABLayout(stride, continuous, continuous, element_size, 0);
     });
 
 } // namespace tl
......
@@ -150,8 +150,8 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
 
 // Default Memory Layout
 Layout makeGemmLayoutLinear(int stride, int continuous);
 Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
-Layout makeGemmABLayout(int stride, int continuous, int element_size,
-                        int kfactor);
+Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
+                        int element_size, int kfactor);
 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                             int kfactor);
......
@@ -186,9 +186,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   results.Set(C, fragment);
   if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-    results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
-                                    *as_const_int(A->shape[1]),
-                                    A->dtype.bits(), trans_A ? 1 : 2));
+    const int64_t mat_stride = *as_const_int(A->shape[0]);
+    const int64_t mat_continuous = *as_const_int(A->shape[1]);
+    results.Set(A,
+                makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
+                                 A->dtype.bits(), trans_A ? 1 : 2));
   } else if (A.scope() == "local.fragment") {
     ICHECK(trans_A == false);
     results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
@@ -197,9 +199,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     ICHECK(0);
   }
   if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-    results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
-                                    *as_const_int(B->shape[1]),
-                                    B->dtype.bits(), trans_B ? 2 : 1));
+    const int64_t mat_stride = *as_const_int(B->shape[0]);
+    const int64_t mat_continuous = *as_const_int(B->shape[1]);
+    results.Set(B,
+                makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
+                                 B->dtype.bits(), trans_B ? 2 : 1));
   } else if (B.scope() == "local.fragment") {
     ICHECK(trans_B == false);
     results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
@@ -222,8 +226,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
           : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
   results.Set(C, fragment);
   if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-    results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
-                                    *as_const_int(A->shape[1]),
-                                    A->dtype.bits(), trans_A ? 1 : 2));
+    const int64_t mat_stride = *as_const_int(A->shape[0]);
+    const int64_t mat_continuous = *as_const_int(A->shape[1]);
+    const int64_t continuity =
+        trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
+    results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
+                                    A->dtype.bits(), trans_A ? 1 : 2));
   } else {
     ICHECK(trans_A == false);
@@ -231,8 +238,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                                      A->dtype.bits()));
   }
   if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-    results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
-                                    *as_const_int(B->shape[1]),
-                                    B->dtype.bits(), trans_B ? 2 : 1));
+    const int64_t mat_stride = *as_const_int(B->shape[0]);
+    const int64_t mat_continuous = *as_const_int(B->shape[1]);
+    const int64_t continuity =
+        trans_B ? mat_continuous : mat_continuous / warp_n;
+    results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
+                                    B->dtype.bits(), trans_B ? 2 : 1));
   } else {
     ICHECK(0) << "WGMMA only support B in shared.";
......
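Note: in the Hopper path above, the continuity handed to makeGemmABLayout is the contiguous extent seen by one warp-group partition, not by the whole block. The helpers below are invented for illustration (they are not in the patch) and merely restate the two expressions from the hunk; the numbers are hypothetical, chosen to match the 192-wide tile in the new test.

// Invented helpers mirroring the continuity expressions above.
constexpr int continuity_A(int mat_continuous, int warp_m, bool trans_A) {
  // trans_A: M is the contiguous axis, split across warp_m / 4 partitions.
  return trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
}
constexpr int continuity_B(int mat_continuous, int warp_n, bool trans_B) {
  // !trans_B: N is the contiguous axis, split warp_n ways.
  return trans_B ? mat_continuous : mat_continuous / warp_n;
}

// Hypothetical shape: a non-transposed B tile 192 elements wide, split two
// ways along N -> each partition sees 96 contiguous elements, so layout
// inference picks the same (narrower) swizzle as the per-warp-group mma atom.
static_assert(continuity_B(192, 2, false) == 96, "per-partition extent");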
@@ -7,6 +7,7 @@
 #include <cute/atom/mma_atom.hpp>
 #include <cutlass/arch/barrier.h>
 #include <cutlass/cutlass.h>
+#include <cutlass/gemm/collective/collective_builder.hpp>
 
 #include "common.h"
@@ -15,65 +16,8 @@ namespace cute {
 using namespace SM90;
 
 namespace tl_wgmma {
 
-template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
-CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
-  auto BLK_MN0 = size<0>(BLK_MN{});
-  auto BLK_K0 = size<0>(BLK_K{});
-  static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
-  static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
-  if constexpr (major == GMMA::Major::MN) {
-    if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_MN_SW128_Atom<ElementType>{};
-    } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_MN_SW64_Atom<ElementType>{};
-    } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_MN_SW32_Atom<ElementType>{};
-    } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_MN_INTER_Atom<ElementType>{};
-    } else {
-      static_assert(
-          BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
-          "BLK_MN0 must be a multiple of "
-          "size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
-    }
-  } else if constexpr (major == GMMA::Major::K) {
-    if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_K_SW128_Atom<ElementType>{};
-    } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_K_SW64_Atom<ElementType>{};
-    } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_K_SW32_Atom<ElementType>{};
-    } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) {
-      return GMMA::Layout_K_INTER_Atom<ElementType>{};
-    } else {
-      static_assert(
-          BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
-          "BLK_K0 must be a multiple of "
-          "size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
-    }
-  }
-}
+using namespace cutlass::gemm::collective::detail; // ss_smem_selector
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, typename A_type_raw, typename B_type_raw,
@@ -92,9 +36,11 @@ public:
       trans_B ? GMMA::Major::K : GMMA::Major::MN;
 
   using SmemLayoutAtomA =
-      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
+      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M / (num_warp_m / 4)>,
+                                Int<K>>());
   using SmemLayoutAtomB =
-      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
+      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N / num_warp_n>,
+                                Int<K>>());
 
   using SmemLayoutA = decltype(tile_to_shape(
       SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
@@ -113,9 +59,10 @@ public:
     Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                             SmemLayoutB{});
     auto tiled_mma = make_tiled_mma(
-        GMMA::ss_op_selector<A_type, B_type, C_type,
-                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
-                             GmmaMajorA, GmmaMajorB>(),
+        GMMA::ss_op_selector<
+            A_type, B_type, C_type,
+            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
+            GmmaMajorA, GmmaMajorB>(),
         Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
     auto thr_mma = tiled_mma.get_thread_slice(tid);
@@ -165,9 +112,10 @@ public:
     Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                             SmemLayoutB{});
     auto tiled_mma = make_tiled_mma(
-        GMMA::rs_op_selector<A_type, B_type, C_type,
-                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
-                             GmmaMajorA, GmmaMajorB>(),
+        GMMA::rs_op_selector<
+            A_type, B_type, C_type,
+            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
+            GmmaMajorA, GmmaMajorB>(),
         Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
     auto thr_mma = tiled_mma.get_thread_slice(tid);
......
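Note: a minimal compile-time sketch of the effect of the hunks above, assuming a CUTLASS version where ss_smem_selector lives in cutlass::gemm::collective::detail (as the new include and using-declaration suggest) and that the snippet is compiled in a CUDA translation unit with CUTLASS on the include path. The concrete extents (192 columns, two partitions along N, fp16, K = 32) are hypothetical and only mirror the new test; this is not part of the patch.

#include <cute/tensor.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>

using namespace cute;
using namespace cutlass::gemm::collective::detail;

// Atom picked when the selector sees the whole 192-wide block (old behaviour)
// versus the 96-wide per-warp-group tile (new behaviour).
using AtomWholeBlock =
    decltype(ss_smem_selector<GMMA::Major::MN, half_t, Int<192>, Int<32>>());
using AtomPerWarpGroup =
    decltype(ss_smem_selector<GMMA::Major::MN, half_t, Int<96>, Int<32>>());

// 64 fp16 elements = 128-byte swizzle atom; 32 elements = 64-byte swizzle atom.
static_assert(size<0>(AtomWholeBlock{}) == 64, "128B swizzle for full block");
static_assert(size<0>(AtomPerWarpGroup{}) == 32, "64B swizzle per warp group");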
import torch
import tilelang
import tilelang.testing
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, threads, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, policy=T.GemmWarpPolicy.FullCol)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_threads_test(threads, M=1024, N=192, K=1024, block_M=64, block_N=192, block_K=32):
    func = matmul(M, N, K, block_M, block_N, block_K, threads)
    jit_kernel = tilelang.compile(func, out_idx=-1, target="cuda")

    torch.manual_seed(0)
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    ref_c = a @ b
    c = jit_kernel(a, b)
    tilelang.testing.torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_threads_2wgs():
    run_gemm_threads_test(128 * 2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_threads_4wgs():
    run_gemm_threads_test(128 * 4)


if __name__ == "__main__":
    tilelang.testing.main()