Commit 9232e7b8 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Update accumulation handling in gemm_sm90.h (#603)

- Replaced the use of `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` for better clarity and maintainability in the accumulation logic.
- This change enhances the readability of the code by standardizing the approach to clearing accumulation values across multiple sections of the file.
parent 3ca5a4ba
......@@ -413,7 +413,7 @@ public:
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
clear(acc);
}
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
......@@ -446,7 +446,7 @@ public:
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
clear(acc);
}
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
CUTE_UNROLL
......@@ -481,7 +481,7 @@ public:
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
clear(acc);
}
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
CUTE_UNROLL
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment