Commit 701e9234 authored by Yu Cheng, committed by LeiWang1999

[Enhancement] Add zero initialization option to GEMM operations (#246)

* [Enhancement] Add zero initialization option to GEMM operations

- Introduced a new `zero_init` parameter to the GEMM function, allowing for optional zero initialization of the accumulator.
- Updated the GEMM implementation across various CUDA architectures to support the new parameter.
- Modified the Python interface for GEMM to include the `zero_init` argument, enhancing flexibility in kernel execution.
- Ensured compatibility with existing functionality while improving initialization control for performance optimization.

* rename zero_init to clear_accum

* lint
parent 71537ba5
@@ -46,14 +46,15 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
   N = args[6].as<IntImm>().value()->value;
   K = args[7].as<IntImm>().value()->value;
   policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
-  if (args.size() > 9) {
-    kPack = args[9].as<IntImm>().value()->value;
+  clear_accum = args[9].as<Bool>().value();
+  if (args.size() > 10) {
+    kPack = args[10].as<IntImm>().value()->value;
     if (kPack != 1 && kPack != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
-  if (args.size() > 10) {
-    wg_wait = args[10].as<IntImm>().value()->value;
+  if (args.size() > 11) {
+    wg_wait = args[11].as<IntImm>().value()->value;
   }
 }
@@ -132,6 +133,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
   ss << warp_m << ", " << warp_n << ", ";
   ss << trans_A << ", " << trans_B;
+  ss << ", " << clear_accum;
   if (TargetIsCDNA(T.target)) {
     // for cdna gemm, we need to specify kPack
     ss << ", " << kPack;
......
@@ -38,6 +38,7 @@ private:
   tir::Buffer A, B, C;
   bool trans_A, trans_B;
   int M, N, K;
+  bool clear_accum = false;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
   int kPack = 1;
......
@@ -61,8 +61,8 @@ struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
 };
 template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type_raw, typename B_type_raw,
-          typename C_type_raw>
+          bool trans_B, bool clear_accum, typename A_type_raw,
+          typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
   using A_type = A_type_raw;
@@ -125,6 +125,9 @@ public:
     IteratorB iter_B(ref_B, lane_id);
     iter_A.add_tile_offset({warp_idx_m, 0});
     iter_B.add_tile_offset({0, warp_idx_n});
+    if constexpr (clear_accum) {
+      accum.clear();
+    }
     CUTLASS_PRAGMA_UNROLL
     for (int k = 0; k < kKgroups; ++k) {
       iter_A.load(frag_A);
@@ -143,6 +146,9 @@ public:
     const TensorRefB ref_B((B_type *)pB, stride_B);
     IteratorB iter_B(ref_B, lane_id);
     iter_B.add_tile_offset({0, warp_idx_n});
+    if constexpr (clear_accum) {
+      accum.clear();
+    }
     CUTLASS_PRAGMA_UNROLL
     for (int k = 0; k < kKgroups; ++k) {
       iter_B.load(frag_B);
@@ -155,10 +161,11 @@ public:
 namespace tl {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
-                           trans_B, A_type, B_type, C_type>;
+                           trans_B, clear_accum, A_type, B_type, C_type>;
   using FragmentC = typename MMA::FragmentC;
   int warp_id = threadIdx.x / 32;
   int lane_id = threadIdx.x % 32;
@@ -167,10 +174,11 @@ CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
-                           trans_B, A_type, B_type, C_type>;
+                           trans_B, clear_accum, A_type, B_type, C_type>;
   using FragmentA = typename MMA::FragmentA;
   using FragmentC = typename MMA::FragmentC;
   int warp_id = threadIdx.x / 32;
......
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 #pragma once
+#include <cute/algorithm/clear.hpp>
 #include <cute/arch/mma_sm80.hpp>
 #include <cute/atom/mma_atom.hpp>
 #include <cute/underscore.hpp>
@@ -183,8 +184,8 @@ struct OperandTraits<64, N, K, false,
 };
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type_raw, typename B_type_raw,
-          typename C_type_raw>
+          bool trans_B, bool clear_accum, typename A_type_raw,
+          typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
   using A_type =
@@ -250,6 +251,9 @@ public:
         make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                     partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
+    if constexpr (clear_accum) {
+      clear(acc);
+    }
     // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
     // workaround
     auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
@@ -284,6 +288,9 @@ public:
         make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                     partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
+    if constexpr (clear_accum) {
+      clear(acc);
+    }
     auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
     copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
     CUTE_UNROLL
@@ -317,6 +324,9 @@ public:
         make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                     partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
+    if constexpr (clear_accum) {
+      clear(acc);
+    }
     auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
     copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
     CUTE_UNROLL
@@ -334,26 +344,29 @@ public:
 namespace tl {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body(pA, pB, accum);
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body_rs(pA, pB, accum);
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body_sr(pA, pB, accum);
 }
......
@@ -20,8 +20,8 @@ namespace tl_wgmma {
 using namespace cutlass::gemm::collective::detail; // ss_smem_selector
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type_raw, typename B_type_raw,
-          typename C_type_raw>
+          bool trans_B, bool clear_accum, typename A_type_raw,
+          typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
   using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
@@ -79,6 +79,9 @@ public:
     warpgroup_fence_operand(acc);
     warpgroup_arrive();
+    if constexpr (clear_accum) {
+      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+    }
     CUTLASS_PRAGMA_UNROLL
     for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
       // warpgroup_arrive();
@@ -132,6 +135,9 @@ public:
     warpgroup_fence_operand(tCrA);
     warpgroup_fence_operand(acc);
     warpgroup_arrive();
+    if constexpr (clear_accum) {
+      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+    }
     CUTLASS_PRAGMA_UNROLL
     for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
       // warpgroup_arrive();
@@ -335,8 +341,8 @@ struct OperandTraits<64, N, K, false,
 };
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type_raw, typename B_type_raw,
-          typename C_type_raw>
+          bool trans_B, bool clear_accum, typename A_type_raw,
+          typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
   using A_type =
@@ -406,6 +412,9 @@ public:
     // workaround
     auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
     auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
+    if constexpr (clear_accum) {
+      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+    }
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
       copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
@@ -437,6 +446,9 @@ public:
                     partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
     auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
+    if constexpr (clear_accum) {
+      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+    }
     copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
@@ -470,6 +482,9 @@ public:
                     partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
     auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
+    if constexpr (clear_accum) {
+      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+    }
     copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
@@ -490,64 +505,67 @@ namespace tl {
 namespace tl_mma {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body(pA, pB, accum);
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body_rs(pA, pB, accum);
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, typename A_type, typename B_type, typename C_type>
+          bool trans_B, bool clear_accum, typename A_type, typename B_type,
+          typename C_type>
 CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, A_type, B_type, C_type>;
   MMA::body_sr(pA, pB, accum);
 }
 } // namespace tl_mma
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
-          typename B_type, typename C_type>
+          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   if constexpr (use_wgmma) {
-    using MMA =
-        cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                     trans_B, A_type, B_type, C_type>;
+    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                             trans_A, trans_B, clear_accum,
+                                             A_type, B_type, C_type>;
     MMA::body<wg_wait>(pA, pB, accum);
   } else {
-    using MMA =
-        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                   trans_B, A_type, B_type, C_type>;
+    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                           trans_A, trans_B, clear_accum,
+                                           A_type, B_type, C_type>;
    MMA::body(pA, pB, accum);
   }
 }
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
-          typename B_type, typename C_type>
+          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   if constexpr (use_wgmma) {
-    using MMA =
-        cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                     trans_B, A_type, B_type, C_type>;
+    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                             trans_A, trans_B, clear_accum,
+                                             A_type, B_type, C_type>;
     MMA::body_rs<wg_wait>(pA, pB, accum);
   } else {
-    using MMA =
-        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                   trans_B, A_type, B_type, C_type>;
+    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                           trans_A, trans_B, clear_accum,
+                                           A_type, B_type, C_type>;
     MMA::body_rs(pA, pB, accum);
   }
 }
......
@@ -13,6 +13,7 @@ def gemm(
     transpose_A: bool = False,
     transpose_B: bool = False,
     policy: GemmWarpPolicy = GemmWarpPolicy.Square,
+    clear_accum: bool = False,
    k_pack: int = 1,
     wg_wait: int = 0,
 ):
@@ -41,6 +42,7 @@ def gemm(
         N,
         K,
         policy,
+        clear_accum,
         k_pack,
         wg_wait,
     )
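
With the Python interface updated, a kernel can fold zero-initialization of the accumulator into the GEMM call itself. Below is a minimal usage sketch, assuming a tilelang-style frontend imported as `T`; the buffer names and the contrasting `T.clear` call are illustrative and not part of this diff:

```python
# Hypothetical call site (names are illustrative, not from this commit).
# Before this change the accumulator fragment had to be zeroed with a
# separate intrinsic:
#
#     T.clear(C_local)
#     T.gemm(A_shared, B_shared, C_local)
#
# With the new flag the zero-fill is fused into the GEMM lowering:
T.gemm(A_shared, B_shared, C_local, clear_accum=True)

# Note: clear_accum is a compile-time constant baked into the template
# instantiation, so it zeroes C_local on *every* invocation of this gemm.
# It only replaces T.clear when each call starts a fresh accumulation;
# keep it False for a gemm that accumulates across a K-loop.
```

Under the hood, the MMA paths clear the fragment directly (`accum.clear()` / `clear(acc)`) before the K-group loop, while the SM90 WGMMA path sets `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero`, so the first warpgroup MMA overwrites the accumulator instead of adding to it.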