"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "5a41d69b2ab7d133a27e4f6d5666982c73a5b5ad"
Commit 701e9234 authored by Yu Cheng, committed by LeiWang1999
Browse files

[Enhancement] Add zero initialization option to GEMM operations (#246)

* [Enhancement] Add zero initialization option to GEMM operations

- Introduced a new `zero_init` parameter to the GEMM function, allowing for optional zero initialization of the accumulator.
- Updated the GEMM implementation across various CUDA architectures to support the new parameter.
- Modified the Python interface for GEMM to include the `zero_init` argument, enhancing flexibility in kernel execution.
- Ensured compatibility with existing functionality while improving initialization control for performance optimization.

* rename zero_init to clear_accum

* lint
parent 71537ba5
......@@ -46,14 +46,15 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
N = args[6].as<IntImm>().value()->value;
K = args[7].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
if (args.size() > 9) {
kPack = args[9].as<IntImm>().value()->value;
clear_accum = args[9].as<Bool>().value();
if (args.size() > 10) {
kPack = args[10].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 10) {
wg_wait = args[10].as<IntImm>().value()->value;
if (args.size() > 11) {
wg_wait = args[11].as<IntImm>().value()->value;
}
}
......@@ -132,6 +133,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
if (TargetIsCDNA(T.target)) {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
......
......@@ -38,6 +38,7 @@ private:
tir::Buffer A, B, C;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
......
......@@ -61,8 +61,8 @@ struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
};
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = A_type_raw;
......@@ -125,6 +125,9 @@ public:
IteratorB iter_B(ref_B, lane_id);
iter_A.add_tile_offset({warp_idx_m, 0});
iter_B.add_tile_offset({0, warp_idx_n});
if constexpr (clear_accum) {
accum.clear();
}
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_A.load(frag_A);
......@@ -143,6 +146,9 @@ public:
const TensorRefB ref_B((B_type *)pB, stride_B);
IteratorB iter_B(ref_B, lane_id);
iter_B.add_tile_offset({0, warp_idx_n});
if constexpr (clear_accum) {
accum.clear();
}
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_B.load(frag_B);
......@@ -155,10 +161,11 @@ public:
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
......@@ -167,10 +174,11 @@ CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
using FragmentA = typename MMA::FragmentA;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
......
......@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/underscore.hpp>
......@@ -183,8 +184,8 @@ struct OperandTraits<64, N, K, false,
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
......@@ -250,6 +251,9 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
if constexpr (clear_accum) {
clear(acc);
}
// when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
// workaround
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
......@@ -284,6 +288,9 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (clear_accum) {
clear(acc);
}
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
CUTE_UNROLL
......@@ -317,6 +324,9 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (clear_accum) {
clear(acc);
}
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
CUTE_UNROLL
......@@ -334,26 +344,29 @@ public:
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
......
......@@ -20,8 +20,8 @@ namespace tl_wgmma {
using namespace cutlass::gemm::collective::detail; // ss_smem_selector
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
......@@ -79,6 +79,9 @@ public:
warpgroup_fence_operand(acc);
warpgroup_arrive();
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
......@@ -132,6 +135,9 @@ public:
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
warpgroup_arrive();
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
......@@ -335,8 +341,8 @@ struct OperandTraits<64, N, K, false,
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
......@@ -406,6 +412,9 @@ public:
// workaround
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
......@@ -437,6 +446,9 @@ public:
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
......@@ -470,6 +482,9 @@ public:
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
......@@ -490,64 +505,67 @@ namespace tl {
namespace tl_mma {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
trans_B, clear_accum, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
} // namespace tl_mma
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
typename B_type, typename C_type>
bool trans_B, bool clear_accum = false, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
using MMA =
cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body<wg_wait>(pA, pB, accum);
} else {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
typename B_type, typename C_type>
bool trans_B, bool clear_accum = false, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
using MMA =
cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body_rs<wg_wait>(pA, pB, accum);
} else {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
}
......
......@@ -13,6 +13,7 @@ def gemm(
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
):
......@@ -41,6 +42,7 @@ def gemm(
N,
K,
policy,
clear_accum,
k_pack,
wg_wait,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.