Commit 549416f7 authored by LeiWang1999

Merge branch 'main' of https://github.com/microsoft/TileLang into main

parents 4d63633a 7fad4e88
@@ -8,222 +8,245 @@

namespace tl {

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global"
               ".mbarrier::complete_tx::bytes [%0], [%1, {%3}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global"
               ".mbarrier::complete_tx::bytes [%0], [%1, {%3, %4}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global"
               ".mbarrier::complete_tx::bytes [%0], [%1, {%3, %4, %5}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2,
                        int32_t const &crd3) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global"
               ".mbarrier::complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2,
                        int32_t const &crd3, int32_t const &crd4) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global"
               ".mbarrier::complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
               : "memory");
}

TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
                               uint64_t &smem_mbar, void const *const smem_ptr,
                               int32_t const &coord_c, int32_t const &coord_w,
                               int32_t const &coord_h, int32_t const &coord_n,
                               uint16_t const &offset_w,
                               uint16_t const &offset_h) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col"
               ".mbarrier::complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
                 "h"(offset_w), "h"(offset_h)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
      : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group"
               " [%0, {%2, %3}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group"
               " [%0, {%2, %3, %4}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group"
               " [%0, {%2, %3, %4, %5}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2), "r"(crd3)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3, int32_t const &crd4) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group"
               " [%0, {%2, %3, %4, %5, %6}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2), "r"(crd3), "r"(crd4)
               : "memory");
}

TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}

TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.init.shared.b64 [%1], %0;"
               :
               : "r"(arrive_count), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("{\n"
               ".reg .pred P1;\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
               "@!P1 bra.uni LAB_WAIT;\n"
               "}\n"
               :
               : "r"(smem_int_ptr), "r"(phase_bit));
}

TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
                                  uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
                                         uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
               :
               : "r"(smem_int_ptr));
}

TL_DEVICE void fence_proxy_async() {
  asm volatile("fence.proxy.async.shared::cta;" : :);
}

TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  uint64_t state;
  asm volatile("{\n"
               ".reg .pred P1;\n"
               "mbarrier.arrive.shared.b64 %1, [%0];\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
               "@!P1 bra.uni LAB_WAIT;\n"
               "}\n"
               :
               : "r"(smem_int_ptr), "l"(state));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

} // namespace tl
\ No newline at end of file
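
The intrinsics above are designed to compose into an mbarrier handshake: initialize the barrier, post the expected transaction byte count, issue the TMA copy, then spin on the barrier phase. Below is a minimal usage sketch, not part of this commit; the kernel name, tile shape, and the host-built `CUtensorMap` are illustrative assumptions, following the standard CUDA TMA recipe.

// Sketch (sm90+): one elected thread issues a 2-D TMA load into shared
// memory; every thread of the block then waits on the mbarrier phase.
// Assumes `desc` was encoded on the host (e.g. via cuTensorMapEncodeTiled)
// for a 32x32 float tile.
__global__ void tma_load_tile_kernel(const __grid_constant__ CUtensorMap desc) {
  __shared__ alignas(128) float tile[32 * 32];
  __shared__ uint64_t mbar;
  constexpr uint32_t kTileBytes = 32 * 32 * sizeof(float);

  if (threadIdx.x == 0) {
    tl::mbarrier_init(mbar, blockDim.x); // each thread arrives once per phase
    tl::fence_proxy_async(); // make the init visible to the async proxy
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    // Arrive and announce how many bytes the TMA engine will deposit.
    tl::mbarrier_arrive_expect_tx(mbar, kTileBytes);
    tl::tma_load(desc, mbar, tile, /*crd0=*/0, /*crd1=*/0);
  } else {
    tl::mbarrier_arrive(mbar);
  }
  // Phase 0 completes once all threads have arrived and kTileBytes landed.
  tl::mbarrier_wait(mbar, /*phase_bit=*/0);
  // tile[] now holds the data; compute can proceed.
}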
@@ -13,78 +13,94 @@ using cutlass::gemm::GemmShape;

// Add 128 bits padding when the last dim is a multiple of 256 bits
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? M : K;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? K : N;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K,
                                   typename std::enable_if<M % 64 == 0>::type> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
  static int constexpr Stride = M;
};

template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = M;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = N;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
                                   typename std::enable_if<N % 64 == 0>::type> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
  static int constexpr Stride = N;
};

template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
 public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;
  using InstructionShape = GemmShape<16, 16, 4>;

  using SMemLayoutA =
      typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                           Shape::kK>::Layout;
  using SMemLayoutB =
      typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                           Shape::kK>::Layout;
  static constexpr int stride_A =
      DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                  Shape::kK>::Stride;
  static constexpr int stride_B =
      DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                  Shape::kK>::Stride;

  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          InstructionShape, 32, A_type,
          typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          B_type,
          typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
      cutlass::MatrixShape<1, 1>>;

  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);

  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
      GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
                InstructionShape::kK>,
      A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
      cutlass::layout::RowMajor, Policy>;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
@@ -97,13 +113,14 @@ class GemmTensorOp {
  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

  static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                                  FragmentC &accum, const int warp_idx_m,
                                  const int warp_idx_n, const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_A;
    FragmentB frag_B;
    const TensorRefA ref_A((A_type *)pA, stride_A);
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorA iter_A(ref_A, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
@@ -118,11 +135,12 @@ class GemmTensorOp {
    }
  }

  static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
                                     FragmentC &accum, const int warp_idx_n,
                                     const int lane_id) {
    MmaWarp mma_op;
    FragmentB frag_B;
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorB iter_B(ref_B, lane_id);
    iter_B.add_tile_offset({0, warp_idx_n});
    CUTLASS_PRAGMA_UNROLL
@@ -136,27 +154,29 @@

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
            warp_id % num_warp_n, lane_id);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentA = typename MMA::FragmentA;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
               warp_id % num_warp_n, lane_id);
}

} // namespace tl
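
As a point of reference, a direct instantiation of the Volta wrapper above might look like the sketch below; the tile shape, warp split, and element types are illustrative assumptions, and the caller is responsible for staging `sA`/`sB` in shared memory and zero-initializing the accumulator fragment.

// Sketch: one 64x64x32 tile per 128-thread block (2x2 warps), fp16 in/out.
// `acc` must point at this thread's MMA::FragmentC storage, zeroed before
// the first call; trans_A/trans_B = false select the row-major shared-memory
// layouts dispatched above.
__device__ void gemm_tile_sm70(cutlass::half_t *sA, cutlass::half_t *sB,
                               cutlass::half_t *acc) {
  tl::gemm_ss<64, 64, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
              /*trans_A=*/false, /*trans_B=*/false, cutlass::half_t,
              cutlass::half_t, cutlass::half_t>(sA, sB, acc);
}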
@@ -12,39 +12,32 @@ template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <> struct DispatchInstruction<half_t, half_t, half_t> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<half_t, half_t, float> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<int8_t, int8_t, int> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<double, double, double> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Layout<Shape<_2, _2, _1>>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <> struct DispatchInstruction<half_t, half_t, float> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _2>>;
};
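
DispatchInstruction is a pure type map from the (A, B, C) element triple to a PTX MMA atom plus a repetition layout. An illustration of what it yields for fp16 inputs with fp32 accumulation on sm80 (using only the names defined above):

// Illustrative only: resolves to the sm80 16x8x16 fp16->fp32 tensor-core op.
using Inst = cute::DispatchInstruction<cute::half_t, cute::half_t, float>;
// Inst::MMA       == MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>
// Inst::MMA_Group == Layout<Shape<_1, _2, _1>>, i.e. the atom is repeated
// twice along N when the TiledMMA below is built from it.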
@@ -54,149 +47,175 @@ template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<64, N, K, true,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<64, N, K, false,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
 public:
  // float operands are lowered to tfloat32_t so they can use tensor cores.
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction = DispatchInstruction<A_type, B_type, C_type>;

  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // original layout fails to compile; this is currently used as a workaround.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
@@ -212,10 +231,12 @@ class GemmTensorOp {
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // When the layout is KxN and n_warp is 1, there seems to be a bug;
    // remove_swizzle is used as a workaround.
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
@@ -226,9 +247,11 @@
    }
  }

  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
@@ -239,10 +262,12 @@
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
@@ -255,9 +280,11 @@
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
@@ -268,10 +295,12 @@
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
@@ -285,32 +314,32 @@
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl
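
The padding rule in the primary OperandTraits template is easiest to see with concrete numbers; the values below are an illustrative check, not part of this commit.

// fp16 operand (Bits = 16), K innermost, K = 48: 48 % 64 is neither 0 nor 32,
// so no swizzled specialization matches and the primary template applies.
// stride = K = 48, and 48 % (256 / 16) == 0, so 128 / 16 = 8 half elements
// (16 bytes) of padding are added per row, sidestepping smem bank conflicts.
static_assert(cute::OperandTraits<16, 64, 48, true>::padded == 56,
              "48-wide fp16 rows are padded to 56 elements");
// For K = 64 the K % 64 == 0 specialization wins instead: a Swizzle<3,3,3>
// 8x64 atom with an ldmatrix copy (SM75_U32x4_LDSM_N) and no padding.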
@@ -2,9 +2,9 @@
// Licensed under the MIT License.

#pragma once

#include <cute/algorithm/copy.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>

#include "common.h"
...@@ -19,78 +19,112 @@ CUTE_HOST_DEVICE constexpr auto ss_smem_selector() { ...@@ -19,78 +19,112 @@ CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
if constexpr (major == GMMA::Major::MN) { if constexpr (major == GMMA::Major::MN) {
if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) { if constexpr (BLK_MN0 %
size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW128_Atom<ElementType>{}; return GMMA::Layout_MN_SW128_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_SW64_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW64_Atom<ElementType>{}; return GMMA::Layout_MN_SW64_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_SW32_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW32_Atom<ElementType>{}; return GMMA::Layout_MN_SW32_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_INTER_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_INTER_Atom<ElementType>{}; return GMMA::Layout_MN_INTER_Atom<ElementType>{};
} else { } else {
static_assert( static_assert(
BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0, BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
"BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})"); "BLK_MN0 must be a multiple of "
"size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
} }
} else if constexpr (major == GMMA::Major::K) { } else if constexpr (major == GMMA::Major::K) {
if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) { if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW128_Atom<ElementType>{}; return GMMA::Layout_K_SW128_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_K0 %
size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW64_Atom<ElementType>{}; return GMMA::Layout_K_SW64_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_K0 %
size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW32_Atom<ElementType>{}; return GMMA::Layout_K_SW32_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) { } else if constexpr (BLK_K0 %
size<1>(
GMMA::Layout_K_INTER_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_INTER_Atom<ElementType>{}; return GMMA::Layout_K_INTER_Atom<ElementType>{};
} else { } else {
static_assert( static_assert(
BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0, BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
"BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})"); "BLK_K0 must be a multiple of "
"size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
} }
} }
} }
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename A_type_raw, typename B_type_raw, typename C_type_raw> bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp { class GemmTensorOp {
public: public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value, tfloat32_t, A_type_raw>; using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
using B_type = conditional_t<std::is_same<B_type_raw, float>::value, tfloat32_t, B_type_raw>; tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using C_type = C_type_raw; using C_type = C_type_raw;
static constexpr GMMA::Major GmmaMajorA = trans_A ? GMMA::Major::MN : GMMA::Major::K; static constexpr GMMA::Major GmmaMajorA =
static constexpr GMMA::Major GmmaMajorB = trans_B ? GMMA::Major::K : GMMA::Major::MN; trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
trans_B ? GMMA::Major::K : GMMA::Major::MN;
using SmemLayoutAtomA = decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>()); using SmemLayoutAtomA =
using SmemLayoutAtomB = decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>()); decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
using SmemLayoutAtomB =
decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{}, using SmemLayoutA = decltype(tile_to_shape(
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{})); SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{}, conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{})); using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
// static_assert(num_warp_n == 1); // static_assert(num_warp_n == 1);
static_assert(num_warp_m % 4 == 0); static_assert(num_warp_m % 4 == 0);
template <int wg_wait=0> template <int wg_wait = 0>
static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) { static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x; const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{}); Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{}); SmemLayoutA{});
auto tiled_mma = Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
make_tiled_mma(GMMA::ss_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>, SmemLayoutB{});
GmmaMajorA, GmmaMajorB>(), auto tiled_mma = make_tiled_mma(
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{}); GMMA::ss_op_selector<A_type, B_type, C_type,
Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid); auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining // Allocate registers for pipelining
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)), Tensor acc =
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{})); make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(acc); warpgroup_fence_operand(acc);
warpgroup_arrive(); warpgroup_arrive();
...@@ -103,7 +137,9 @@ class GemmTensorOp { ...@@ -103,7 +137,9 @@ class GemmTensorOp {
} }
warpgroup_commit_batch(); warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(acc); warpgroup_fence_operand(acc);
// warpgroup_fence_operand(acc); // warpgroup_fence_operand(acc);
// warpgroup_arrive(); // warpgroup_arrive();
...@@ -115,25 +151,31 @@ class GemmTensorOp { ...@@ -115,25 +151,31 @@ class GemmTensorOp {
// warpgroup_fence_operand(acc); // warpgroup_fence_operand(acc);
} }
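The call sequence encodes the WGMMA handshake: warpgroup_fence_operand marks acc as an operand of the asynchronous MMA, warpgroup_arrive opens the async group, the (elided) loop issues one gemm per K step with the accumulator scale flipped to One after the first step (visible in body_rs below), warpgroup_commit_batch closes the group, and warpgroup_wait<wg_wait> blocks until at most wg_wait groups remain in flight. Passing wg_wait < 0 skips the wait entirely so a later wait_wgmma call can drain the batch.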
template <int wg_wait=0> template <int wg_wait = 0>
static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) { static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
// TODO: Move bar.sync out of body_rs // TODO: Move bar.sync out of body_rs
// asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * 32)); // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
// 32));
const int tid = threadIdx.x; const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{}); Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
auto tiled_mma = SmemLayoutB{});
make_tiled_mma(GMMA::rs_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>, auto tiled_mma = make_tiled_mma(
GmmaMajorA, GmmaMajorB>(), GMMA::rs_op_selector<A_type, B_type, C_type,
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{}); Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid); auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining // Allocate registers for pipelining
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast<A_type*>(pA)), Tensor tCrA =
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{})); make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)), partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{})); Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(tCrA); warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc); warpgroup_fence_operand(acc);
...@@ -146,7 +188,9 @@ class GemmTensorOp { ...@@ -146,7 +188,9 @@ class GemmTensorOp {
tiled_mma.accumulate_ = GMMA::ScaleOut::One; tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} }
warpgroup_commit_batch(); warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(acc); warpgroup_fence_operand(acc);
warpgroup_fence_operand(tCrA); warpgroup_fence_operand(tCrA);
...@@ -156,57 +200,63 @@ class GemmTensorOp { ...@@ -156,57 +200,63 @@ class GemmTensorOp {
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc); // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch(); // warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc); // warpgroup_fence_operand(acc);
} }
}; };
} // namespace cute } // namespace cute
namespace tl { namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename A_type, typename B_type, typename C_type> bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) { typename C_type>
using MMA = TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>; using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body<wg_wait>(pA, pB, accum); MMA::body<wg_wait>(pA, pB, accum);
} }
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename A_type, typename B_type, typename C_type> bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) { typename C_type>
using MMA = TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>; using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_rs<wg_wait>(pA, pB, accum); MMA::body_rs<wg_wait>(pA, pB, accum);
} }
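As a usage sketch (the tile shape, warp count, and staging code are illustrative assumptions; only the gemm_ss call itself comes from the code above), a consumer warpgroup would invoke the wrapper once the shared-memory tiles are ready:

// Hedged sketch: one CTA computes a 128x128 tile, K = 64 per call,
// 8 warps along M (two warpgroups, 256 threads), so each thread holds
// 128*128/256 = 64 accumulator floats.
__global__ void tile_gemm(cutlass::half_t *gA, cutlass::half_t *gB, float *gC) {
  extern __shared__ char smem[];
  auto *sA = reinterpret_cast<cutlass::half_t *>(smem);             // 128x64
  auto *sB = reinterpret_cast<cutlass::half_t *>(smem) + 128 * 64;  // 128x64
  float acc[64] = {0.f};
  // ... producer copies gA/gB into sA/sB; a barrier flips here ...
  tl::gemm_ss<128, 128, 64, /*num_warp_m=*/8, /*num_warp_n=*/1,
              /*trans_A=*/false, /*trans_B=*/false, /*wg_wait=*/0,
              cutlass::half_t, cutlass::half_t, float>(sA, sB, acc);
  // ... epilogue writes acc back to gC ...
}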
template <int num_mma> template <int num_mma> TL_DEVICE void wait_wgmma() {
TL_DEVICE void wait_wgmma() {
warpgroup_wait<num_mma>(); warpgroup_wait<num_mma>();
} }
template <int NumMmaThreads> template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
TL_DEVICE void warp_scheduler_barrier_sync() { cutlass::arch::NamedBarrier::sync(NumMmaThreads,
cutlass::arch::NamedBarrier::sync( cutlass::canonical_warp_group_idx() /*id*/);
NumMmaThreads,
cutlass::canonical_warp_group_idx() /*id*/);
} }
template <int NumMmaThreads> template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
TL_DEVICE void warp_scheduler_barrier_arrive() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
if constexpr (NumMmaThreads == 256) { if constexpr (NumMmaThreads == 256) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/); cutlass::arch::NamedBarrier::arrive(
NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
} else { } else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); cutlass::arch::NamedBarrier::arrive(
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 0 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); NumMmaThreads,
(cutlass::canonical_warp_group_idx() <= 1
? cutlass::canonical_warp_group_idx() + 1
: cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
(cutlass::canonical_warp_group_idx() <= 0
? cutlass::canonical_warp_group_idx() + 2
: cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
} }
} }
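Worked through by hand (w = canonical warpgroup index), the arrive ids form a rotation in which every warpgroup releases the other schedulers while it waits on its own id via warp_scheduler_barrier_sync:

  NumMmaThreads == 256: w=0 arrives at id 1; w=1 arrives at id 0.
  NumMmaThreads == 384: w=0 arrives at ids 1 and 2; w=1 at ids 2 and 0; w=2 at ids 0 and 1.

So a warpgroup's sync on barrier id w completes only after each of the other warpgroups has arrived there, serializing the schedulers round-robin.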
template <int NumMmaThreads> template <int NumMmaThreads> TL_DEVICE void mma_init() {
TL_DEVICE void mma_init() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
if (cutlass::canonical_warp_group_idx() > 0) { if (cutlass::canonical_warp_group_idx() > 0) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
...@@ -217,4 +267,4 @@ TL_DEVICE void mma_init() { ...@@ -217,4 +267,4 @@ TL_DEVICE void mma_init() {
} }
} }
} }
} // namespace tl } // namespace tl
...@@ -6,97 +6,118 @@ ...@@ -6,97 +6,118 @@
namespace tl { namespace tl {
TL_DEVICE void ptx_ldmatrix_x1(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0]) : "=r"(value[0])
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_ldmatrix_x2(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1]) : "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_ldmatrix_x4(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" asm volatile(
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "r"(smem_int_ptr)); : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_ldmatrix_x1_trans(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0]) : "=r"(value[0])
: "r"(smem_int_ptr)); : "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_ldmatrix_x2_trans(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" asm volatile(
: "=r"(value[0]), "=r"(value[1]) "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "r"(smem_int_ptr)); : "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_ldmatrix_x4_trans(void const* const smem_ptr, void* const local_ptr) { TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr); int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" asm volatile(
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "r"(smem_int_ptr)); : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
} }
TL_DEVICE void ptx_stmatrix_x1(void const* const smem_ptr, const int32_t& value0) { TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr,
const int32_t &value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr), asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
smem_int_ptr),
"r"(value0)); "r"(value0));
} }
TL_DEVICE void ptx_stmatrix_x2(void const* const smem_ptr, const int32_t& value0, TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr,
const int32_t& value1) { const int32_t &value0, const int32_t &value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr), asm volatile(
"r"(value0), "r"(value1)); "stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
smem_int_ptr),
"r"(value0), "r"(value1));
} }
TL_DEVICE void ptx_stmatrix_x4(void const* const smem_ptr, const int32_t& value0, TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr,
const int32_t& value1, const int32_t& value2, const int32_t &value0, const int32_t &value1,
const int32_t& value3) { const int32_t &value2, const int32_t &value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile( asm volatile(
"stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(smem_int_ptr), "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::
"r"(smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3)); "r"(value0), "r"(value1), "r"(value2), "r"(value3));
} }
TL_DEVICE void ptx_stmatrix_x1_trans(void const* const smem_ptr, const int32_t& value0) { TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr,
const int32_t &value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr), asm volatile(
"r"(value0)); "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
smem_int_ptr),
"r"(value0));
} }
TL_DEVICE void ptx_stmatrix_x2_trans(void const* const smem_ptr, const int32_t& value0, TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr,
const int32_t& value1) { const int32_t &value0,
const int32_t &value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile( asm volatile(
"stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr), "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
smem_int_ptr),
"r"(value0), "r"(value1)); "r"(value0), "r"(value1));
} }
TL_DEVICE void ptx_stmatrix_x4_trans(void const* const smem_ptr, const int32_t& value0, TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr,
const int32_t& value1, const int32_t& value2, const int32_t &value0,
const int32_t& value3) { const int32_t &value1,
const int32_t &value2,
const int32_t &value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"( asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, "
smem_int_ptr), "%3, %4};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3)); "r"(value0), "r"(value1), "r"(value2), "r"(value3));
} }
} // namespace tl } // namespace tl
\ No newline at end of file \ No newline at end of file
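These wrappers map one-to-one onto the PTX ldmatrix/stmatrix instructions, so the caller owns the per-lane addressing: each of the 32 lanes in the warp passes the shared-memory address of one 8x8 row, and the x1/x2/x4 variants select how many 8x8 b16 tiles move per instruction. A hedged sketch of a call site (the lane-to-row mapping is deliberately left abstract; it is defined by the PTX ISA, not by this header):

// uint32_t frag[4];                 // four registers, one per 8x8 tile
// void const *row_ptr = /* lane-dependent address of an 8x8 row in smem;
//                          see the PTX ISA ldmatrix section for the map */;
// tl::ptx_ldmatrix_x4(row_ptr, frag);
// ... after the MMA, mirror the fragment back with ptx_stmatrix_x4 ...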
...@@ -7,34 +7,29 @@ ...@@ -7,34 +7,29 @@
namespace tl { namespace tl {
struct SumOp { struct SumOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return x + y; return x + y;
} }
}; };
struct MaxOp { struct MaxOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return cutlass::fast_max(x, y); return cutlass::fast_max(x, y);
} }
}; };
struct MinOp { struct MinOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return cutlass::fast_min(x, y); return cutlass::fast_min(x, y);
} }
}; };
template <class Reducer, int threads, int scale> template <class Reducer, int threads, int scale> struct AllReduce {
struct AllReduce { static_assert(threads == 1024 or threads == 512 or threads == 256 or
static_assert(threads == 1024 or threads == 512 or threads == 256 or threads == 128 or threads == 128 or threads == 64 or threads == 32 or
threads == 64 or threads == 32 or threads == 16 or threads == 8 or threads == 4 or threads == 16 or threads == 8 or threads == 4 or threads == 2);
threads == 2);
static_assert(threads % scale == 0); static_assert(threads % scale == 0);
template <typename T> template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
static TL_DEVICE T run(T x, T* red_buf = nullptr) {
constexpr int offset = threads / 2; constexpr int offset = threads / 2;
if constexpr (offset >= 32) { if constexpr (offset >= 32) {
__syncthreads(); __syncthreads();
...@@ -54,4 +49,4 @@ struct AllReduce { ...@@ -54,4 +49,4 @@ struct AllReduce {
} }
}; };
} // namespace tl } // namespace tl
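A hedged usage sketch (the block size, the shared scratch buffer, and its sizing are assumptions inferred from the visible signature; the cross-warp path that consumes red_buf is elided in this hunk):

// Every thread of a 256-thread block ends up holding the block-wide max.
__global__ void block_max(const float *in, float *out) {
  __shared__ float red_buf[256];  // scratch for the offset >= 32 phase
  float v = in[blockIdx.x * blockDim.x + threadIdx.x];
  v = tl::AllReduce<tl::MaxOp, 256, 1>::run(v, red_buf);
  if (threadIdx.x == 0) out[blockIdx.x] = v;
}

The scale parameter (1 here) appears to split the block into threads / scale independent reduction groups, which is consistent with the threads % scale == 0 static_assert but not spelled out in this hunk.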
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
namespace tl { namespace tl {
template <int panel_width> template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
TL_DEVICE dim3 rasterization2DRow() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x; const unsigned int panel_size = panel_width * gridDim.x;
...@@ -15,15 +14,17 @@ TL_DEVICE dim3 rasterization2DRow() { ...@@ -15,15 +14,17 @@ TL_DEVICE dim3 rasterization2DRow() {
const unsigned int panel_idx = block_idx / panel_size; const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride = const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x; panel_idx + 1 < total_panel
const unsigned int col_idx = ? panel_width
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride; : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx = (panel_idx & 1)
? gridDim.x - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z}; return {col_idx, row_idx, blockIdx.z};
} }
template <int panel_width> template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y; const unsigned int panel_size = panel_width * gridDim.y;
...@@ -31,11 +32,14 @@ TL_DEVICE dim3 rasterization2DColumn() { ...@@ -31,11 +32,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int panel_idx = block_idx / panel_size; const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride = const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y; panel_idx + 1 < total_panel
const unsigned int row_idx = ? panel_width
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride; : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx = (panel_idx & 1)
? gridDim.y - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z}; return {col_idx, row_idx, blockIdx.z};
} }
} // namespace tl } // namespace tl
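Both helpers remap the linear block id into a panel-major, serpentine traversal: consecutive CTAs stay inside a strip panel_width blocks wide, and alternate panels are walked in opposite directions, which improves L2 reuse between neighboring CTAs. A host-side re-implementation of the column variant for illustration (gridDim becomes plain parameters; panel_offset = block_idx % panel_size is an assumption, since its definition falls in the elided lines):

#include <cstdio>
// Mirrors rasterization2DColumn for a (gx, gy) grid; returns (col, row).
void remap(unsigned block_idx, unsigned gx, unsigned gy, unsigned panel_width,
           unsigned *col, unsigned *row) {
  unsigned grid_size = gx * gy;
  unsigned panel_size = panel_width * gy;
  unsigned panel_offset = block_idx % panel_size;  // assumed; elided above
  unsigned panel_idx = block_idx / panel_size;
  unsigned total_panel = (grid_size + panel_size - 1) / panel_size;
  unsigned stride = panel_idx + 1 < total_panel
                        ? panel_width
                        : (grid_size - panel_idx * panel_size) / gy;
  *row = (panel_idx & 1) ? gy - 1 - panel_offset / stride
                         : panel_offset / stride;
  *col = panel_offset % stride + panel_idx * panel_width;
}
int main() {
  // 6x4 grid, panel_width = 2: blocks 0..7 sweep panel 0 downward,
  // blocks 8..15 sweep panel 1 upward, and so on.
  for (unsigned i = 0; i < 24; i++) {
    unsigned c, r;
    remap(i, 6, 4, 2, &c, &r);
    std::printf("block %2u -> (col %u, row %u)\n", i, c, r);
  }
}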
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
// Licensed under the MIT License. // Licensed under the MIT License.
#pragma once #pragma once
#include <hip/hip_runtime.h> #include <ck_tile/core.hpp>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <rocwmma/rocwmma.hpp> #include <rocwmma/rocwmma.hpp>
#include <ck_tile/core.hpp>
using ck_tile::half_t; using ck_tile::half_t;
...@@ -36,12 +36,16 @@ using ck_tile::half_t; ...@@ -36,12 +36,16 @@ using ck_tile::half_t;
using float16_t = _Float16; using float16_t = _Float16;
using float16x2 = __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t; using float16x2 =
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x8 = __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t; using float16x4 =
using float16x16 = __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t; __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 =
__attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 =
__attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
...@@ -49,7 +53,7 @@ using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; ...@@ -49,7 +53,7 @@ using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;
// Pack two half_t values. // Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) { TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short*)&x); unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short*)&y); unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0; return (v1 << 16) | v0;
} }
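For instance, with IEEE half bit patterns, __pack_half2(1.0, 2.0) combines 0x3C00 and 0x4000 into 0x40003C00: the second operand lands in the high half-word.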
...@@ -16,12 +16,13 @@ using index_t = u32; ...@@ -16,12 +16,13 @@ using index_t = u32;
using ck_tile::int32x4_t; using ck_tile::int32x4_t;
struct __attribute__((packed)) buffer_resource { struct __attribute__((packed)) buffer_resource {
const void* ptr; const void *ptr;
uint32_t range; uint32_t range;
uint32_t config; uint32_t config;
}; };
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
uint32_t size = 0xffffffff) {
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res); int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x); r.x = __builtin_amdgcn_readfirstlane(r.x);
...@@ -56,48 +57,56 @@ __device__ void async_gld_sld_fence(index_t cnt) { ...@@ -56,48 +57,56 @@ __device__ void async_gld_sld_fence(index_t cnt) {
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); } __device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }
template <int N = 0> template <int N = 0> TL_DEVICE void cp_async_wait() {
TL_DEVICE void cp_async_wait() {
async_gld_fence(N); async_gld_fence(N);
// or // or
// async_gld_sld_fence(N); // async_gld_sld_fence(N);
} }
template <bool pre_nop = false> template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, int32x4_t rsrc, index_t voffset) { CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem))); index_t voffset) {
asm volatile( auto const lds_ptr_sgpr =
"s_mov_b32 m0, %0; \n\t" __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), asm volatile("s_mov_b32 m0, %0; \n\t"
"v"(voffset), "s"(rsrc) "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
: "memory"); "v"(voffset), "s"(rsrc)
: "memory");
} }
template <int N> template <int N>
TL_DEVICE void cp_async_gs(void* lds_base_ptr, void* global_base_ptr) { TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
if constexpr(N == 16) { if constexpr (N == 16) {
*(uint4*)lds_base_ptr = *(uint4*)global_base_ptr; *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
} else if constexpr(N == 8) { } else if constexpr (N == 8) {
*(uint2*)lds_base_ptr = *(uint2*)global_base_ptr; *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
} else if constexpr(N == 4) { } else if constexpr (N == 4) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/); async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
} }
} }
template <int N> template <int N>
TL_DEVICE void cp_async_gs_conditional(void* lds_base_ptr, void* global_base_ptr, bool cond) { TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
if constexpr(N == 16){ void *global_base_ptr, bool cond) {
*(uint4*)lds_base_ptr = cond? *(uint4*)global_base_ptr: make_uint4(0,0,0,0); if constexpr (N == 16) {
}else if constexpr(N == 8){ *(uint4 *)lds_base_ptr =
*(uint2*)lds_base_ptr = cond? *(uint2*)global_base_ptr: make_uint2(0,0); cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
}else{ } else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr =
cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
} else {
if (cond) { if (cond) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/); async_buffer_load_dword_v(
}else{ lds_base_ptr,
*(uint4*)lds_base_ptr = make_uint4(0,0,0,0); make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
} else {
*(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
} }
} }
} }
} // namespace tl } // namespace tl
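A hedged sketch of how these pieces compose in a producer loop (the pointer arithmetic and tile shape are illustrative; only the three calls come from the code above). With _Float16 pointers, a step of 8 elements is 16 bytes, matching the N = 16 uint4 path:

// for (int i = threadIdx.x; i < tile_elems / 8; i += blockDim.x)
//   tl::cp_async_gs<16>(lds_tile + i * 8, gmem_tile + i * 8);
// tl::cp_async_wait<0>();  // async_gld_fence(0): drain outstanding loads
// wave_barrier();          // s_barrier before consumers touch the tile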
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
namespace tl { namespace tl {
// ref to bitblas/tl/mfma_macro_generator.py::kPack // ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA, bool TransposeB, int kPack, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
typename A_type, typename B_type, typename C_type, typename AccDataType = float> bool TransposeB, int kPack, typename A_type, typename B_type,
typename C_type, typename AccDataType = float>
class GemmTensorOp { class GemmTensorOp {
public: public:
static constexpr int micro_size_x = 16; static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16; static constexpr int micro_size_y = 16;
static constexpr int micro_size_k = 16; static constexpr int micro_size_k = 16;
...@@ -28,7 +28,8 @@ class GemmTensorOp { ...@@ -28,7 +28,8 @@ class GemmTensorOp {
static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x); static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y); static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
// part.
static constexpr bool kPadA = true; static constexpr bool kPadA = true;
static constexpr bool kPadB = true; static constexpr bool kPadB = true;
static constexpr bool kPadC = true; static constexpr bool kPadC = true;
...@@ -37,12 +38,16 @@ class GemmTensorOp { ...@@ -37,12 +38,16 @@ class GemmTensorOp {
static constexpr int warp_size = 64; static constexpr int warp_size = 64;
TL_DEVICE static constexpr auto reverse_index_map(int thread_id, int local_id) { TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
return std::make_pair(thread_id % 16, (thread_id / 16) * (4 * kPack) + local_id); int local_id) {
return std::make_pair(thread_id % 16,
(thread_id / 16) * (4 * kPack) + local_id);
} }
TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id, int local_id) { TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
return std::make_pair((thread_id / 16) * (4 * kPack) + local_id, thread_id % 16); int local_id) {
return std::make_pair((thread_id / 16) * (4 * kPack) + local_id,
thread_id % 16);
} }
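For kPack = 1 these maps put lane t at row t % 16 and columns (t / 16) * 4 + local_id; for example, lane 17 with local_id 2 touches element (1, 6), and the transposed variant returns the swapped pair.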
/* /*
...@@ -62,7 +67,8 @@ class GemmTensorOp { ...@@ -62,7 +67,8 @@ class GemmTensorOp {
const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); const int maxPhase =
std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
const int phase = (row / perPhase) % maxPhase; const int phase = (row / perPhase) % maxPhase;
const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize); const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
...@@ -73,16 +79,19 @@ class GemmTensorOp { ...@@ -73,16 +79,19 @@ class GemmTensorOp {
} }
template <int continuous = 32, int element_size = 2> template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_layout_padded(const int row, const int col) { TL_DEVICE static constexpr auto make_layout_padded(const int row,
const int col) {
return std::make_pair(row, col); return std::make_pair(row, col);
} }
template <int continuous = 32, int element_size = 2> template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_swizzle_layout(const int row, const int col) { TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
const int col) {
constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8); constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);
if (continuous % (vector_size * 4) == 0) { if (continuous % (vector_size * 4) == 0) {
auto [n_row, n_col] = make_mfma_swizzle_layout<continuous, element_size>(row, col); auto [n_row, n_col] =
make_mfma_swizzle_layout<continuous, element_size>(row, col);
return n_row * continuous + n_col; return n_row * continuous + n_col;
} else { } else {
auto [n_row, n_col] = make_layout_padded(row, col); auto [n_row, n_col] = make_layout_padded(row, col);
...@@ -93,7 +102,8 @@ class GemmTensorOp { ...@@ -93,7 +102,8 @@ class GemmTensorOp {
} }
} }
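The swizzle spreads rows across LDS banks by xor-ing the vector-column index with a row-derived phase. A host-side sketch with assumed constants (numBanks = 32, bankBitWidth = 32, SIMDWidth = 16, vecSize = 4, 16-bit elements, innerDimLength = 64; the real constants and the col % vecSize term live in the elided parts of this header):

#include <algorithm>
#include <cstdio>
int swizzled_col(int row, int col) {
  const int elems_per_banks_row = 32 * 32 / 16;                 // 64
  const int per_phase = std::max(1, elems_per_banks_row / 64);  // 1
  const int max_phase = std::min(16 / per_phase, 64 / 4);       // 16
  const int phase = (row / per_phase) % max_phase;
  return ((col / 4) ^ phase) * 4 + col % 4;  // swizzled + in-vector offset
}
int main() {
  for (int row = 0; row < 4; row++)  // col 0 lands on 0, 4, 8, 12, ...
    std::printf("row %d: col 0 -> %d\n", row, swizzled_col(row, 0));
}

Successive rows thus start in different bank groups, so reading a column of the tile no longer serializes on a single bank.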
static TL_DEVICE void body(A_type* A_shared, B_type* B_shared, C_type* C_local) { static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x; auto tid = threadIdx.x;
auto warp_id = tid / warp_size; auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps; auto warp_n = warp_id / block_row_warps;
...@@ -122,7 +132,8 @@ class GemmTensorOp { ...@@ -122,7 +132,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) { for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id); auto [row, col] = reverse_index_map(lane_id, local_id);
A_local[i * kPack * local_size_a + local_id] = A_local[i * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(l + row, r + col)]; A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)];
} }
} }
...@@ -133,7 +144,8 @@ class GemmTensorOp { ...@@ -133,7 +144,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id); auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] = B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)]; B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} }
} }
...@@ -141,17 +153,19 @@ class GemmTensorOp { ...@@ -141,17 +153,19 @@ class GemmTensorOp {
for (int kp = 0; kp < kPack; kp++) { for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) { for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) { for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16( *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
*(((float16x4*)B_local) + j * kPack + kp), __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4*)A_local) + i * kPack + kp), *(((float16x4 *)B_local) + j * kPack + kp),
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0); *(((float16x4 *)A_local) + i * kPack + kp),
*(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
} }
} }
} }
} }
} }
static TL_DEVICE void body_rs(A_type* A_local, B_type* B_shared, C_type* C_local) { static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x; auto tid = threadIdx.x;
auto warp_id = tid / warp_size; auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps; auto warp_n = warp_id / block_row_warps;
...@@ -179,7 +193,8 @@ class GemmTensorOp { ...@@ -179,7 +193,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < kPack * local_size_b; local_id++) { for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id); auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * local_size_b + local_id] = B_local[j * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)]; B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} }
} }
...@@ -187,9 +202,12 @@ class GemmTensorOp { ...@@ -187,9 +202,12 @@ class GemmTensorOp {
for (int kp = 0; kp < kPack; kp++) { for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) { for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) { for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16( *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
*(((float16x4*)B_local) + j * kPack + kp), *(((float16x4*)A_local) + ki * warp_rows * kPack + i * kPack + kp), __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0); *(((float16x4 *)B_local) + j * kPack + kp),
*(((float16x4 *)A_local) + ki * warp_rows * kPack +
i * kPack + kp),
*(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
} }
} }
} }
...@@ -197,24 +215,26 @@ class GemmTensorOp { ...@@ -197,24 +215,26 @@ class GemmTensorOp {
} }
}; };
} // namespace tl } // namespace tl
namespace tl { namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename A_type, typename B_type, typename C_type> bool trans_B, int kPack, typename A_type, typename B_type,
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) { typename C_type>
using Compute = TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>; using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, kPack, A_type, B_type, C_type>;
Compute::body(pA, pB, accum); Compute::body(pA, pB, accum);
} }
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename A_type, typename B_type, typename C_type> bool trans_B, int kPack, typename A_type, typename B_type,
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) { typename C_type>
using Compute = TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>; using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, kPack, A_type, B_type, C_type>;
Compute::body_rs(pA, pB, accum); Compute::body_rs(pA, pB, accum);
} }
} // namespace tl } // namespace tl
...@@ -7,35 +7,30 @@ ...@@ -7,35 +7,30 @@
namespace tl { namespace tl {
struct SumOp { struct SumOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return x + y; return x + y;
} }
}; };
struct MaxOp { struct MaxOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return ck_tile::max(x, y); return ck_tile::max(x, y);
} }
}; };
struct MinOp { struct MinOp {
template <typename T> template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
TL_DEVICE T operator()(T const& x, T const& y) {
return ck_tile::min(x, y); return ck_tile::min(x, y);
} }
}; };
template <class Reducer, int threads, int scale> template <class Reducer, int threads, int scale> struct AllReduce {
struct AllReduce { static_assert(threads == 1024 || threads == 512 || threads == 256 ||
static_assert(threads == 1024 || threads == 512 || threads == 256 || threads == 128 || threads == 128 || threads == 64 || threads == 32 ||
threads == 64 || threads == 32 || threads == 16 || threads == 8 || threads == 4 || threads == 16 || threads == 8 || threads == 4 || threads == 2);
threads == 2);
static_assert(threads % scale == 0); static_assert(threads % scale == 0);
template <typename T> template <typename T> static __device__ T run(T x, T *red_buf = nullptr) {
static __device__ T run(T x, T* red_buf = nullptr) {
constexpr int offset = threads / 2; constexpr int offset = threads / 2;
constexpr int warpSize = 64; constexpr int warpSize = 64;
...@@ -55,4 +50,4 @@ struct AllReduce { ...@@ -55,4 +50,4 @@ struct AllReduce {
} }
}; };
} // namespace tl } // namespace tl
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
namespace tl { namespace tl {
template <int panel_width> template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
TL_DEVICE dim3 rasterization2DRow() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int grid_size = gridDim.x * gridDim.y;
...@@ -16,15 +15,17 @@ TL_DEVICE dim3 rasterization2DRow() { ...@@ -16,15 +15,17 @@ TL_DEVICE dim3 rasterization2DRow() {
const unsigned int panel_idx = block_idx / panel_size; const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size); const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride = const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x; panel_idx + 1 < total_panel
const unsigned int col_idx = ? panel_width
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride; : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx = (panel_idx & 1)
? gridDim.x - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z}; return {col_idx, row_idx, blockIdx.z};
} }
template <int panel_width> template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
TL_DEVICE dim3 rasterization2DColumn() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int grid_size = gridDim.x * gridDim.y;
...@@ -33,11 +34,14 @@ TL_DEVICE dim3 rasterization2DColumn() { ...@@ -33,11 +34,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int panel_idx = block_idx / panel_size; const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size); const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride = const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y; panel_idx + 1 < total_panel
const unsigned int row_idx = ? panel_width
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride; : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx = (panel_idx & 1)
? gridDim.y - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z}; return {col_idx, row_idx, blockIdx.z};
} }
} // namespace tl } // namespace tl
...@@ -31,15 +31,17 @@ namespace tvm { ...@@ -31,15 +31,17 @@ namespace tvm {
namespace tir { namespace tir {
class ClusterPlanner { class ClusterPlanner {
public: public:
static PrimFunc Substitute(PrimFunc& f) { static PrimFunc Substitute(PrimFunc &f) {
// Step 1: Collect the read region of the function // Step 1: Collect the read region of the function
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
for (const auto& [_, buffer] : f->buffer_map) { for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer); buffer_data_to_buffer_.Set(buffer->data, buffer);
} }
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ f->body); Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); /*body*/ f->body);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0]; auto reads = access[0];
BlockIdxVisitor blockIdx_visitor; BlockIdxVisitor blockIdx_visitor;
...@@ -47,20 +49,22 @@ class ClusterPlanner { ...@@ -47,20 +49,22 @@ class ClusterPlanner {
auto dom_map = blockIdx_visitor.dom_map_; auto dom_map = blockIdx_visitor.dom_map_;
// Step 2: Collect mem reuse count for clustering on each dimension. // Step 2: Collect mem reuse count for clustering on each dimension.
std::unordered_map<const IterVarNode*, size_t> mem_reuse_count; std::unordered_map<const IterVarNode *, size_t> mem_reuse_count;
for (auto iv : dom_map) mem_reuse_count[iv] = 0; for (auto iv : dom_map)
mem_reuse_count[iv] = 0;
for (const auto& buffer_region : reads) { for (const auto &buffer_region : reads) {
PrimExpr size = buffer_region->buffer->dtype.bits(); PrimExpr size = buffer_region->buffer->dtype.bits();
RegionVisitor visitor; RegionVisitor visitor;
for (const auto& range : buffer_region->region) { for (const auto &range : buffer_region->region) {
size = size * range->extent; size = size * range->extent;
visitor(range->min); visitor(range->min);
} }
size = arith::Analyzer().Simplify(size); size = arith::Analyzer().Simplify(size);
if (auto imm = size.as<IntImmNode>()) { if (auto imm = size.as<IntImmNode>()) {
for (auto iv : dom_map) { for (auto iv : dom_map) {
if (visitor.seen_.count(iv->var.get()) == 0) mem_reuse_count[iv] += imm->value; if (visitor.seen_.count(iv->var.get()) == 0)
mem_reuse_count[iv] += imm->value;
} }
} }
} }
...@@ -70,7 +74,8 @@ class ClusterPlanner { ...@@ -70,7 +74,8 @@ class ClusterPlanner {
String cluster_tag; String cluster_tag;
for (auto iv : dom_map) { for (auto iv : dom_map) {
if (auto extent = iv->dom->extent.as<IntImmNode>()) { if (auto extent = iv->dom->extent.as<IntImmNode>()) {
if (extent->value % cluster_size_ == 0 && mem_reuse_count[iv] > mem_reuse_max) { if (extent->value % cluster_size_ == 0 &&
mem_reuse_count[iv] > mem_reuse_max) {
cluster_tag = iv->thread_tag; cluster_tag = iv->thread_tag;
mem_reuse_max = mem_reuse_count[iv]; mem_reuse_max = mem_reuse_count[iv];
} }
...@@ -78,27 +83,28 @@ class ClusterPlanner { ...@@ -78,27 +83,28 @@ class ClusterPlanner {
} }
if (mem_reuse_max > 0) { if (mem_reuse_max > 0) {
cluster_tag = "clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx")); cluster_tag =
"clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx"));
return WithAttr(f, cluster_tag, Integer(cluster_size_)); return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else { } else {
return f; return f;
} }
} }
private: private:
ClusterPlanner() = default; ClusterPlanner() = default;
class RegionVisitor : public ExprVisitor { class RegionVisitor : public ExprVisitor {
public: public:
RegionVisitor(){}; RegionVisitor(){};
void VisitExpr_(const VarNode* var) { seen_.insert(var); } void VisitExpr_(const VarNode *var) { seen_.insert(var); }
std::unordered_set<const VarNode*> seen_; std::unordered_set<const VarNode *> seen_;
}; };
class BlockIdxVisitor : public StmtVisitor { class BlockIdxVisitor : public StmtVisitor {
public: public:
BlockIdxVisitor(){}; BlockIdxVisitor(){};
void VisitStmt_(const AttrStmtNode* attr) final { void VisitStmt_(const AttrStmtNode *attr) final {
if (attr->attr_key == attr::thread_extent) { if (attr->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(attr->node); IterVar iv = Downcast<IterVar>(attr->node);
String tag = iv->thread_tag; String tag = iv->thread_tag;
...@@ -108,7 +114,7 @@ class ClusterPlanner { ...@@ -108,7 +114,7 @@ class ClusterPlanner {
StmtVisitor::VisitStmt_(attr); StmtVisitor::VisitStmt_(attr);
} }
/*! \brief The set of collected blockIdx iter vars. */ /*! \brief The set of collected blockIdx iter vars. */
std::unordered_set<const IterVarNode*> dom_map_; std::unordered_set<const IterVarNode *> dom_map_;
}; };
/*! \brief Currently sets the possible cluster size to 2 */ /*! \brief Currently sets the possible cluster size to 2 */
...@@ -126,8 +132,9 @@ tvm::transform::Pass ClusterPlanning() { ...@@ -126,8 +132,9 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning").set_body_typed(ClusterPlanning); TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning")
} // namespace transform .set_body_typed(ClusterPlanning);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
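A worked example of the heuristic (shapes are illustrative): suppose the function, launched over blockIdx.x and blockIdx.y, reads A[blockIdx.y*64 .. blockIdx.y*64+64, 0 .. K] and B[0 .. K, blockIdx.x*64 .. blockIdx.x*64+64]. A's region minimums never mention blockIdx.x, so every block along x re-reads the same A tile and mem_reuse_count[blockIdx.x] grows by A's size in bits; symmetrically, B credits blockIdx.y. If A is the larger read and gridDim.x is divisible by 2, the pass attaches a clusterIdx.x attribute with value 2, pairing x-adjacent blocks into a cluster that can share the A tile.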
...@@ -32,10 +32,10 @@ ...@@ -32,10 +32,10 @@
#include <queue> #include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../../op/parallel.h" #include "../../op/parallel.h"
#include "../loop_partition.h" #include "../loop_partition.h"
#include "../loop_vectorize.h" #include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -44,15 +44,15 @@ using namespace tir; ...@@ -44,15 +44,15 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer; using arith::IRMutatorWithAnalyzer;
class FragmentAccessDetector : public StmtExprVisitor { class FragmentAccessDetector : public StmtExprVisitor {
public: public:
FragmentAccessDetector() = default; FragmentAccessDetector() = default;
void Collect(Stmt stmt) { VisitStmt(stmt); } void Collect(Stmt stmt) { VisitStmt(stmt); }
bool HasFragmentAccess() { return has_fragment_access_; } bool HasFragmentAccess() { return has_fragment_access_; }
private: private:
void VisitExpr_(const BufferLoadNode* op) final { void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is a local.fragment buffer // Check if the buffer is a local.fragment buffer
if (IsFragmentBuffer(op->buffer)) { if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true; has_fragment_access_ = true;
...@@ -60,7 +60,7 @@ class FragmentAccessDetector : public StmtExprVisitor { ...@@ -60,7 +60,7 @@ class FragmentAccessDetector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
void VisitStmt_(const BufferStoreNode* op) final { void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is a local.fragment buffer // Check if the buffer is a local.fragment buffer
if (IsFragmentBuffer(op->buffer)) { if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true; has_fragment_access_ = true;
...@@ -69,8 +69,9 @@ class FragmentAccessDetector : public StmtExprVisitor { ...@@ -69,8 +69,9 @@ class FragmentAccessDetector : public StmtExprVisitor {
} }
// Helper function to determine if a buffer is local.fragment // Helper function to determine if a buffer is local.fragment
bool IsFragmentBuffer(const Buffer& buffer) { bool IsFragmentBuffer(const Buffer &buffer) {
// The storage scope is often encoded in the buffer->data var name or associated attributes. // The storage scope is often encoded in the buffer->data var name or
// associated attributes.
String scope = buffer.scope(); String scope = buffer.scope();
return scope == "local.fragment"; return scope == "local.fragment";
} }
...@@ -87,23 +88,25 @@ class FragmentAccessDetector : public StmtExprVisitor { ...@@ -87,23 +88,25 @@ class FragmentAccessDetector : public StmtExprVisitor {
* Once fused, a single loop variable will replace the chain, and the * Once fused, a single loop variable will replace the chain, and the
* original loop variables will be derived by division and modulo operations. * original loop variables will be derived by division and modulo operations.
* *
* This can be helpful for inferring layout for the fragment in a subsequent pass. * This can be helpful for inferring layout for the fragment in a subsequent
* pass.
*/ */
class ParallelLoopFuser : public IRMutatorWithAnalyzer { class ParallelLoopFuser : public IRMutatorWithAnalyzer {
public: public:
static Stmt Fuse(Stmt stmt) { static Stmt Fuse(Stmt stmt) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
ParallelLoopFuser substituter(&analyzer); ParallelLoopFuser substituter(&analyzer);
return substituter.VisitStmt(stmt); return substituter.VisitStmt(stmt);
} }
private: private:
ParallelLoopFuser(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}; ParallelLoopFuser(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer){};
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode *op) final {
// Gather consecutive parallel loops // Gather consecutive parallel loops
std::vector<const ForNode*> loop_chain; std::vector<const ForNode *> loop_chain;
const ForNode* current = op; const ForNode *current = op;
// check if has fragment access // check if has fragment access
FragmentAccessDetector detector; FragmentAccessDetector detector;
detector.Collect(op->body); detector.Collect(op->body);
...@@ -113,11 +116,13 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer { ...@@ -113,11 +116,13 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
} }
while (true) { while (true) {
if (current->kind != ForKind::kParallel) break; if (current->kind != ForKind::kParallel)
if (!is_zero(current->min)) break; break;
if (!is_zero(current->min))
break;
loop_chain.push_back(current); loop_chain.push_back(current);
const ForNode* inner_for = current->body.as<ForNode>(); const ForNode *inner_for = current->body.as<ForNode>();
if (!inner_for) { if (!inner_for) {
break; break;
} }
...@@ -147,7 +152,7 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer { ...@@ -147,7 +152,7 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
Var fused_var(fused_name, DataType::Int(32)); Var fused_var(fused_name, DataType::Int(32));
// The body of the last loop in the chain: // The body of the last loop in the chain:
const ForNode* innermost_loop = loop_chain.back(); const ForNode *innermost_loop = loop_chain.back();
Stmt body = innermost_loop->body; Stmt body = innermost_loop->body;
// We need to substitute all loop variables in the chain. // We need to substitute all loop variables in the chain.
...@@ -175,7 +180,8 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer { ...@@ -175,7 +180,8 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
extents.push_back(l->extent); extents.push_back(l->extent);
} }
std::vector<PrimExpr> strides(loop_chain.size(), make_const(DataType::Int(32), 1)); std::vector<PrimExpr> strides(loop_chain.size(),
make_const(DataType::Int(32), 1));
for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) { for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * extents[i + 1]; strides[i] = strides[i + 1] * extents[i + 1];
} }
...@@ -189,8 +195,9 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer { ...@@ -189,8 +195,9 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
Map<Var, PrimExpr> var_map; Map<Var, PrimExpr> var_map;
for (size_t i = 0; i < loop_chain.size(); i++) { for (size_t i = 0; i < loop_chain.size(); i++) {
const ForNode* loop = loop_chain[i]; const ForNode *loop = loop_chain[i];
var_map.Set(loop->loop_var, analyzer_->Simplify(create_index_expr(static_cast<int>(i)))); var_map.Set(loop->loop_var,
analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
} }
// Perform the substitution // Perform the substitution
...@@ -203,5 +210,5 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer { ...@@ -203,5 +210,5 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
} }
}; };
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
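The stride construction is easiest to see with concrete extents. A small host-side check (the extents are illustrative, and the exact body of create_index_expr is elided in this hunk; the div/mod recovery below is the standard construction the surrounding code describes):

#include <cassert>
int main() {
  // Three parallel loops i < 4, j < 8, k < 2 fuse into f < 64.
  // Strides are built back-to-front: {16, 2, 1}.
  for (int f = 0; f < 64; f++) {
    int i = f / 16;       // outermost: divide by its stride
    int j = (f / 2) % 8;  // inner: divide by stride, modulo extent
    int k = f % 2;
    assert((i * 8 + j) * 2 + k == f);  // round-trips to the fused index
  }
  return 0;
}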
@@ -32,10 +32,10 @@
#include <queue>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tl {
@@ -46,7 +46,8 @@ using namespace tir;

// Use the same code as tir.transform.vectorize_loop
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
  if (is_scalable) {
    return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
               lanes_or_vscale_factor);
  } else {
    return lanes_or_vscale_factor;
  }
@@ -58,7 +59,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
      e.dtype().is_scalable_vector() == is_scalable)
    return e;
  if (const BroadcastNode *op = e.as<BroadcastNode>()) {
    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
        << "Can't broadcast between scalable and fixed length vectors.";
    int e_lanes = op->dtype.get_lanes_or_vscale_factor();
@@ -68,40 +69,39 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
    }
  }
  ICHECK(e.dtype().is_scalar())
      << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
      << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;
  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
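// A hedged sketch of BroadcastTo's behavior on the cases handled above:
//   BroadcastTo(3, 4, /*is_scalable=*/false)        -> Broadcast(3, 4)
//   BroadcastTo(Broadcast(x, 2), 8, false)          -> Broadcast(x, 8)
//   BroadcastTo(v /* already 8 lanes */, 8, false)  -> v (unchanged)
// Anything else (e.g. widening a 4-lane non-broadcast vector to 8 lanes)
// trips one of the ICHECKs.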
// Rewrite vectorized allocation access.
// This is necessary so that each vector lane gets its own workspace.
// Originates from Halide's loop vectorizer
//
//   s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple
// contexts.
//
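// Concrete instance (hedged): a per-thread allocation
//   float s[8];  ...  s[i] = f(s[i]);
// vectorized over `var` with 4 lanes becomes
//   float s[32]; ...  s[i * 4 + var] = f(s[i * 4 + var]);
// so each of the 4 simulated contexts owns a disjoint slice of `s`.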
class VecAllocAccess : public StmtExprMutator {
 public:
  VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
      : buf_(buf), var_(var), var_lanes_(var_lanes) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return UpdateBufferAccess(load);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return UpdateBufferAccess(store);
  }

 private:
  template <typename Node> Node UpdateBufferAccess(Node node) {
    // Only update the buffer that's being replaced.
    if (node->buffer->data.get() != buf_) {
      return node;
@@ -117,7 +117,8 @@ class VecAllocAccess : public StmtExprMutator {
    // var_lanes_. Typically, this will be a 1-d index into a flat
    // memory space.
    Array<PrimExpr> shape = node->buffer->shape;
    shape.Set(shape.size() - 1,
              analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
    // TODO(Lunderberg): Move this pass to be prior to
    // StorageFlatten/FlattenBuffer, implement by appending a
@@ -146,8 +147,9 @@ class VecAllocAccess : public StmtExprMutator {
    // Extend the last index by the number of lanes in the vectorized
    // variable.
    Array<PrimExpr> indices = node->indices;
    indices.Set(
        indices.size() - 1,
        analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
    auto writer = node.CopyOnWrite();
    writer->buffer = buf;
@@ -156,9 +158,9 @@ class VecAllocAccess : public StmtExprMutator {
  }
  // buffer var
  const VarNode *buf_;
  // Updated buffer objects.
  std::unordered_map<const BufferNode *, Buffer> buffer_map_;
  // variable to be replaced
  Var var_;
  // the lanes.
@@ -170,8 +172,9 @@ class VecAllocAccess : public StmtExprMutator {

// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
class Vectorizer : public StmtMutator,
                   public ExprFunctor<PrimExpr(const PrimExpr &)> {
 public:
  using ExprFunctor::VisitExpr;
  using StmtMutator::operator();
@@ -179,7 +182,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
  }
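  // Note: ramp_ = Ramp(0, 1, var_lanes) denotes the vector value
  // [0, 1, ..., var_lanes - 1]; substituting it for the loop variable is
  // what turns a scalar loop body into a var_lanes-wide vector expression.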
  Stmt VisitStmt(const Stmt &stmt) final {
    ICHECK(!need_scalarize_);
    Stmt ret = StmtMutator::VisitStmt(stmt);
    if (need_scalarize_) {
@@ -190,17 +193,19 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }

  PrimExpr VisitExpr(const PrimExpr &e) final {
    return ExprFunctor::VisitExpr(e);
  }

  PrimExpr VisitExpr_(const AddNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
  }
  PrimExpr VisitExpr_(const SubNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
  }
  PrimExpr VisitExpr_(const MulNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
@@ -211,11 +216,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
      if (is_vec_a && is_vec_b) {
        // Let's not multiply scalable and fixed length vectors
        ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
            << "Fixed length and scalable vectors can't be mixed in "
               "multiplication.";
      }
      if (is_vec_a || is_vec_b) {
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
        if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
          PrimExpr lanes = a_ramp->lanes;
          return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
@@ -227,28 +233,34 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
        int a_lanes = a.dtype().get_lanes_or_vscale_factor();
        int b_lanes = b.dtype().get_lanes_or_vscale_factor();
        int max_lanes = std::max(a_lanes, b_lanes);
        bool is_scalable =
            a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
        return Mul(BroadcastTo(a, max_lanes, is_scalable),
                   BroadcastTo(b, max_lanes, is_scalable));
      }
    }
    return BinaryVec<Mul>(op);
  }
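  // Strength-reduction example for the a_ramp/b_ramp cases above (hedged):
  // multiplying a ramp by a provably positive scalar c keeps the result a
  // ramp,
  //   Ramp(b, s, n) * c  ->  Ramp(b * c, s * c, n)
  // e.g. [0, 1, 2, 3] * 4 -> Ramp(0, 4, 4) = [0, 4, 8, 12], so later
  // address computation stays affine instead of falling back to a generic
  // broadcast multiply.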
  PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
  PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
  PrimExpr VisitExpr_(const FloorDivNode *op) final {
    return BinaryVec<FloorDiv>(op);
  }
  PrimExpr VisitExpr_(const FloorModNode *op) final {
    return BinaryVec<FloorMod>(op);
  }
  PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
  PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
  PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
  PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
  PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
  PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
  PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
  PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
  PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
  PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }

  PrimExpr VisitExpr_(const NotNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    if (a.same_as(op->a)) {
      return GetRef<PrimExpr>(op);
@@ -257,7 +269,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }

  PrimExpr VisitExpr_(const RampNode *op) final {
    PrimExpr base = this->VisitExpr(op->base);
    PrimExpr stride = this->VisitExpr(op->stride);
    ICHECK(!base.dtype().is_scalable_vector())
@@ -267,11 +279,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
      ICHECK(op->lanes->IsInstance<IntImmNode>())
          << "Vectorizing over existing scalable vectors is not supported.";
      const RampNode *base_ramp = base.as<RampNode>();
      int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
      int base_ramp_lanes =
          static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
      if (analyzer_.CanProve(base_ramp->stride ==
                             stride *
                                 make_const(stride.dtype(), base_ramp_lanes))) {
        return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
      }
    }
@@ -280,13 +294,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    stride = BroadcastTo(stride, lanes, false);
    Array<PrimExpr> elems;
    for (int i = 0; i < lanes; ++i) {
      elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
                           Shuffle::ExtractElement(stride, i), op->lanes));
    }
    return Shuffle::Concat(elems);
  }
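  // Example of the fallback path above (hedged): vectorizing a Ramp whose
  // base is itself a 2-lane vector [b0, b1], with scalar stride s and
  // op->lanes == n, yields
  //   Shuffle::Concat({Ramp(b0, s, n), Ramp(b1, s, n)})
  // i.e. a 2*n-lane vector. The earlier CanProve branch collapses the
  // special case where that concatenation is itself one contiguous ramp.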
  PrimExpr VisitExpr_(const BroadcastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
@@ -299,45 +313,56 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }

  PrimExpr VisitExpr_(const SelectNode *op) final {
    PrimExpr cond = this->VisitExpr(op->condition);
    PrimExpr t = this->VisitExpr(op->true_value);
    PrimExpr f = this->VisitExpr(op->false_value);
    if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
      return GetRef<PrimExpr>(op);
    } else {
      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
      bool is_scalable = cond.dtype().is_scalable_vector() ||
                         t.dtype().is_scalable_vector() ||
                         f.dtype().is_scalable_vector();
      return Select(BroadcastTo(cond, lanes, is_scalable),
                    BroadcastTo(t, lanes, is_scalable),
                    BroadcastTo(f, lanes, is_scalable));
    }
  }

  PrimExpr VisitExpr_(const CastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.same_as(op->value)) {
      return GetRef<PrimExpr>(op);
    } else {
      if (value.dtype().is_scalable_vector()) {
        return Cast(op->dtype.with_scalable_vscale_factor(
                        value.dtype().vscale_factor()),
                    value);
      } else {
        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
      }
    }
  }

  PrimExpr VisitExpr_(const FloatImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const IntImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const StringImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  // Variable
  PrimExpr VisitExpr_(const VarNode *op) final {
    Var var = GetRef<Var>(op);
    if (var.same_as(var_)) {
@@ -351,7 +376,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }
  // IfThenElse expr
  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
    PrimExpr cond = this->VisitExpr(op->args[0]);
    if (cond.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
@@ -359,24 +384,27 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
    PrimExpr t = this->VisitExpr(op->args[1]);
    PrimExpr f = this->VisitExpr(op->args[2]);
    if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
        f.same_as(op->args[2])) {
      return GetRef<PrimExpr>(op);
    } else {
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(t_lanes, f_lanes);
      bool is_scalable =
          t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
      t = BroadcastTo(t, lanes, is_scalable);
      f = BroadcastTo(f, lanes, is_scalable);
      if (is_scalable) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {cond, t, f});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
      }
    }
  }
  // Reinterpret expr
  PrimExpr MutateReinterpretExpr_(const CallNode *op) {
    ICHECK(op->op.same_as(builtin::reinterpret()));
    PrimExpr value = this->VisitExpr(op->args[0]);
    if (value.same_as(op->args[0])) {
@@ -384,14 +412,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    } else {
      int lanes = value.dtype().get_lanes_or_vscale_factor();
      if (value.dtype().is_scalable_vector()) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {value});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {value});
      }
    }
  }
  // Call
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      return MutateIfThenElseExpr_(op);
    } else if (op->op.same_as(builtin::texture2d_load())) {
@@ -406,13 +435,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
      // Vectorize the value to store
      Array<PrimExpr> value{op->args.back()};
      Array<PrimExpr> mutated_value = MutateArray(value, &lane);
      Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                               mutated_value[0]};
      return Call(op->dtype.with_lanes(lane), op->op, new_args);
    } else if (op->op.same_as(builtin::reinterpret())) {
      return MutateReinterpretExpr_(op);
    }
    auto optional_op = op->op.as<Op>();
    bool vectorizable = optional_op &&
                        op_vectorizable_.get(optional_op.value(), false) &&
                        !op->dtype.is_scalable_vector();
    if (!vectorizable) {
@@ -443,10 +474,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }
  // BufferLoad
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = GetRef<BufferLoad>(op);
    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);
    if (!indices.same_as(op->indices)) {
@@ -457,7 +490,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    return std::move(load);
  }
  // Let
  PrimExpr VisitExpr_(const LetNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    // Weaker SSA condition:
    // a single var can be bound in multiple lets
@@ -486,24 +519,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }
  // BufferStore
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = GetRef<BufferStore>(op);
    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);
    PrimExpr value = this->VisitExpr(op->value);
    if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
      ICHECK(!op->buffer->dtype.is_scalable_vector())
          << "Vectorizing over scalable buffer elements is not supported in "
             "vectorizer.";
      // How many lanes of indexing are present in the index and
      // buffer element type, excluding the last index.
      int other_index_lanes = op->buffer->dtype.lanes();
      for (size_t i = 0; i < indices.size() - 1; i++) {
        other_index_lanes *= indices[i].dtype().lanes();
        // Only allow the last index to be scalable
        ICHECK(!indices[i].dtype().is_scalable_vector())
            << "Only the last index can be scalable.";
      }
      // The total number of lanes of indexing, including the last index.
@@ -519,14 +556,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
      int total_lanes = std::max(index_lanes, value_dtype_lanes);
      ICHECK_EQ(total_lanes % other_index_lanes, 0)
          << "When storing to buffer " << op->buffer->name
          << ", cannot produce " << total_lanes
          << " lanes of storage location by changing the last index.";
      int last_index_lanes = total_lanes / other_index_lanes;
      // Broadcast the last index such that the total number of index
      // lanes matches the desired number.
      indices.Set(indices.size() - 1,
                  BroadcastTo(indices[indices.size() - 1], last_index_lanes,
                              is_last_index_scalable));
      auto writer = store.CopyOnWrite();
      writer->indices = indices;
@@ -536,7 +575,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    return std::move(store);
  }
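  // Example (hedged): storing a 4-lane value to A[i], where A has scalar
  // elements and i vectorizes to Ramp(i0, 1, 4), needs no broadcast since
  // index_lanes is already 4. Storing a 4-lane value through a scalar last
  // index instead broadcasts that index to 4 lanes, so the lane counts of
  // the storage locations and of the stored value agree.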
  // For
  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kVectorized) {
      LOG(WARNING) << "Detected vectorize inside a vectorized loop, ignoring...";
    }
@@ -550,12 +589,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return For(op->loop_var, op->min, extent, op->kind, body,
                 op->thread_binding, op->annotations);
    }
  }
  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
@@ -574,13 +613,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }
  // While
  Stmt VisitStmt_(const WhileNode *op) final {
    LOG(FATAL) << "A while loop inside a vectorized loop is not supported.";
  }
  // LetStmt
  Stmt VisitStmt_(const LetStmtNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    ICHECK(!let_binding_.count(op->var))
        << "SSA violation, a single var is bound twice";
    let_binding_[op->var] = value;
    if (value.dtype().get_lanes_or_vscale_factor() !=
@@ -599,20 +639,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    }
  }
  // Allocate
  Stmt VisitStmt_(const AllocateNode *op) final {
    // Mutate the condition
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
      LOG(WARNING) << "Cannot handle vector extent in alloc of "
                   << op->buffer_var->name_hint;
      return Scalarize(GetRef<Stmt>(op));
    }
    // Mutate the extents
    Array<PrimExpr> extents;
    for (const auto &extent : op->extents) {
      PrimExpr new_ext = this->VisitExpr(extent);
      if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
        LOG(WARNING) << "Cannot handle vector extent in alloc of "
                     << op->buffer_var->name_hint;
        return Scalarize(GetRef<Stmt>(op));
      }
      extents.push_back(new_ext);
@@ -629,7 +671,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
    // Rewrite access to the buffer in the body.
    Stmt body =
        VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
    body = this->VisitStmt(body);
    return Allocate(op->buffer_var, op->dtype, extents, condition, body);
  }
@@ -641,11 +684,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
  }
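  // Scalarize (its tail is shown above) is the fallback when a construct
  // cannot be vectorized: the vector loop is re-materialized as a serial
  // `for (idx in [0, var_lanes_))`, presumably with `var_` bound to `idx`
  // inside `stmt` by code elided from this hunk.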
  // ProducerStore
  Stmt VisitStmt_(const ProducerStoreNode *op) final {
    LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
  }

 private:
  // analyzer
  arith::Analyzer analyzer_;
  // deep equal
@@ -661,19 +704,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
  // Let binding
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
  // vectorizable property
  OpAttrMap<TVectorizable> op_vectorizable_ =
      Op::GetAttrMap<TVectorizable>("TVectorizable");
  // Mutate an array with a given lane requirement;
  // when finished, *p_lanes holds the updated lane requirement.
  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
    if (arr.size() == 0)
      return arr;
    int &lanes = *p_lanes;
    bool changed = false;
    std::vector<PrimExpr> new_arr(arr.size());
    for (size_t i = 0; i < arr.size(); i++) {
      PrimExpr old_elem = arr[i];
      PrimExpr new_elem = this->VisitExpr(old_elem);
      if (!new_elem.same_as(old_elem))
        changed = true;
      new_arr[i] = new_elem;
      lanes = std::max(lanes, new_elem.dtype().lanes());
    }
@@ -684,12 +730,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
        changed = true;
      }
    }
    if (!changed)
      return arr;
    return Array<PrimExpr>(new_arr);
  }
  template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
    static_assert(std::is_same<typename TOp::ContainerType, T>::value,
                  "constraint");
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
@@ -698,12 +745,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return TOp(BroadcastTo(a, lanes, is_scalable),
                 BroadcastTo(b, lanes, is_scalable));
    }
  }

  template <typename T, typename FCompute>
  PrimExpr AddSubVec(const T *op, FCompute fcompute) {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
@@ -713,21 +762,25 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      if (lanes != 1) {
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
        if (a.dtype().is_scalar() && b_ramp) {
          return Ramp(
              fcompute(a, b_ramp->base),
              fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
              b_ramp->lanes);
        }
        if (b.dtype().is_scalar() && a_ramp) {
          return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
        }
      }
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return fcompute(BroadcastTo(a, lanes, is_scalable),
                      BroadcastTo(b, lanes, is_scalable));
    }
  }
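  // Ramp/scalar examples for AddSubVec (hedged): with fcompute = (+),
  //   x + Ramp(b, s, n) -> Ramp(x + b, s, n)
  //   Ramp(b, s, n) + x -> Ramp(b + x, s, n)
  // Subtraction differs in the first case: the stride becomes 0 - s, which
  // is why the stride is recomputed as fcompute(make_zero(...), stride)
  // rather than reused directly.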
};

} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -34,19 +34,19 @@ namespace tl {

using namespace tir;

class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
 public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    FrontendLegalizer substituter(&analyzer);
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

 private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  Stmt VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      parallel_for_scope_++;
    }
@@ -57,7 +57,7 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
    return n;
  }

  PrimExpr VisitExpr_(const VarNode *node) final {
    if (let_bindings_.count(node)) {
      return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]);
    } else {
@@ -65,18 +65,18 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
    }
  }

  Stmt VisitStmt_(const LetStmtNode *node) final {
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitStmt(node->body);
  }

  PrimExpr VisitExpr_(const LetNode *node) final {
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitExpr(node->body);
  }
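  // Net effect of the three visitors above (a hedged sketch): let bindings
  // are recorded and dropped, and each later use of the bound variable is
  // replaced by the (recursively legalized) bound value, e.g.
  //   let x = a + b in x * x   ==>   (a + b) * (a + b)
  // i.e. sharing is traded away for a frontend IR without let nodes.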
  int parallel_for_scope_ = 0;
  std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
};

using namespace tir::transform;
@@ -91,5 +91,5 @@ Pass FrontendLegalize() {
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize")
    .set_body_typed(FrontendLegalize);

} // namespace tl
} // namespace tvm
@@ -38,10 +38,10 @@ using namespace tir;

enum class Proxy { kGeneric, kAsync, kBoth };

class ProxyMarker : public StmtVisitor {
 public:
  ProxyMarker() = default;

  Proxy GetProxy(const StmtNode *stmt) const {
    auto it = map_.find(stmt);
    // ICHECK(it != map_.end());
    // TODO: This is a hack implementation to avoid the ICHECK failure.
@@ -51,9 +51,9 @@ class ProxyMarker : public StmtVisitor {
    return it->second;
  }

  Proxy GetProxy(const Stmt &stmt) const { return GetProxy(stmt.get()); }

  void VisitStmt_(const EvaluateNode *op) final {
    Proxy proxy = Proxy::kAsync;
    if (auto call = op->value.as<CallNode>()) {
      if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) {
@@ -63,12 +63,12 @@ class ProxyMarker : public StmtVisitor {
    SetProxy(op, proxy);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    Proxy proxy = Proxy::kGeneric;
    SetProxy(op, proxy);
  }

  void VisitStmt_(const SeqStmtNode *op) final {
    StmtVisitor::VisitStmt_(op);
    auto role = GetProxy(op->seq[0]);
    for (auto stmt : op->seq) {
@@ -80,61 +80,59 @@ class ProxyMarker : public StmtVisitor {
    SetProxy(op, role);
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    StmtVisitor::VisitStmt_(op);
    auto role = GetProxy(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetProxy(op->else_case.value());
      if (role != role_else)
        role = Proxy::kBoth;
    }
    SetProxy(op, role);
  }
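  // Classification example (hedged): a SeqStmt whose statements are all
  // kAsync stays kAsync, while mixing an async intrinsic with a kGeneric
  // BufferStore yields kBoth; likewise, an if/else whose branches disagree
  // is conservatively marked kBoth above.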
  void VisitStmt_(const BlockRealizeNode *op) final {
    StmtVisitor::VisitStmt_(op);
    SetProxy(op, GetProxy(op->block));
  }

  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
    StmtVisitor::VisitStmt_(op);
    SetProxy(op, GetProxy(op->body));
  }

  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }

 private:
  void SetProxy(const StmtNode *stmt, Proxy proxy) { map_[stmt] = proxy; }
  std::unordered_map<const StmtNode *, Proxy> map_;
};

class InjectFenceProxy : public StmtExprMutator {
 public:
  static PrimFunc Substitute(PrimFunc f) {
    auto T = InjectFenceProxy();
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

 private:
  Proxy get_generic_proxy(const Stmt &stmt) {
    auto marker = ProxyMarker();
    marker(stmt);
    return marker.GetProxy(stmt);
  }

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    ICHECK(op->seq.size() > 0);
    Array<Stmt> new_body;
    Proxy cur_proxy, prev_proxy;
    auto fence_stmt =
        Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
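    // FenceProxyAsyncOp() is assumed to lower to the PTX instruction
    // `fence.proxy.async`, ordering prior generic-proxy writes (e.g. plain
    // shared-memory stores) before subsequent async-proxy accesses such as
    // TMA copies; the exact lowering lives outside this hunk.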
    prev_proxy = get_generic_proxy(op->seq[0]);
    new_body.push_back(VisitStmt(op->seq[0]));
    if (op->seq.size() > 1) {
@@ -171,5 +169,5 @@ tvm::transform::Pass InjectFenceProxy() {
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy")
    .set_body_typed(InjectFenceProxy);

} // namespace tl
} // namespace tvm
@@ -19,7 +19,8 @@
/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
@@ -38,24 +39,27 @@ using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
 * The result is an opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without a predicate, it is unnecessary to
 * create a new block; the block of the block realize will be returned.
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
  n->reads = access[0];
  n->writes = access[1];
  return block;
@@ -68,69 +72,76 @@ struct PipelineAnnotation {
  bool async;
};

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;

struct BufferAccessInfo {
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
};

/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
 */
class PipelineBodyRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from the original buffer to the buffer with
   * updated shape for multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions of the buffers in the
   * software pipeline are accessed. This will be used to update the block
   * access region. In the prologue and epilogue of a two-stage software
   * pipeline, only one version of these buffers is accessed.
   */
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
        access_all_versions_(access_all_versions) {}

 private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension
      // to full extent if access_all_versions == true
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }
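  // Example (hedged): a double-buffered shared tile remapped from shape
  // [64, 64] to [2, 64, 64]. With access_all_versions == false, the region
  // gains a leading extent-1 slice at version
  // floormod(loop_var - min, 2); with true, it gains the full [0, 2) range.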
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
@@ -138,62 +149,63 @@ class PipelineBodyRewriter : public StmtExprMutator {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }
  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
@@ -208,24 +220,25 @@ class PipelineBodyRewriter : public StmtExprMutator {
};
/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {}
  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions that need to be maintained for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
@@ -233,27 +246,28 @@ class PipelineRewriter : public StmtExprMutator {
    }
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
@@ -263,13 +277,15 @@ class PipelineRewriter : public StmtExprMutator {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
@@ -277,17 +293,21 @@ class PipelineRewriter : public StmtExprMutator {
      }
    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});
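    // Illustrative example (assumed bounds): for a loop over [0, N) with
    // max_stage_ == 1, the prologue covers [0, 1), the steady-state body
    // covers [1, N), and the epilogue covers [N, N + 1), shifting each stage
    // one iteration behind the previous one.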
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
@@ -296,26 +316,28 @@ class PipelineRewriter : public StmtExprMutator {
    return BlockRealize({}, Bool(true), block);
  }
private:
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
   * This method checks the 'define' and 'use' stage of the buffers in the
   * software pipeline, which is used to compute the number of versions that
   * need to be maintained after rewriting.
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

      for (const BufferRegion &write : block->writes) {
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(write->buffer);
        if (info.def == -1) {
          info.def = stage;
        } else {
@@ -323,11 +345,11 @@ class PipelineRewriter : public StmtExprMutator {
        }
      }

      for (const BufferRegion &read : block->reads) {
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(read->buffer);
        info.use = std::max(info.use, stage);
      }
    }
@@ -355,58 +377,64 @@ class PipelineRewriter : public StmtExprMutator {
  }
  /*!
   * \brief Compute the number of versions that need to be maintained for a
   * buffer accessed in the software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions that need to be maintained during the software
   * pipeline. The annotation `attr::double_buffer_scope` is handled here,
   * which provides a way to override the result of the analysis. Additional
   * double buffering in the software pipeline can be useful to eliminate
   * synchronizations on GPU devices.
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
    if (buffer_info.def == -1) {
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
      return 1;
    }

    // `use - def + 1` is an upper bound of the needed versions.
    // We optimize a few cases where the number of versions can be smaller
    // than the upper bound.
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when there exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
      bool need_multi_version = false;
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;

        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
          if (it2 == reader_block->reads.end()) {
            continue;
          }
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
@@ -421,13 +449,12 @@ class PipelineRewriter : public StmtExprMutator {
  }
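  // Worked example (illustrative): a shared buffer copied in at stage 0 and
  // consumed at stage 1 has def = 0 and use = 1, so use - def + 1 = 2 and the
  // buffer is double-buffered, provided a conflicting writer/reader pair as
  // described above actually exists.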
  /*!
   * \brief Rewrite buffer allocation to keep multiple versions of the
   * original buffer for pipelined accesses.
   * \param buffer The buffer to be resized.
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
@@ -438,29 +465,32 @@ class PipelineRewriter : public StmtExprMutator {
    return Buffer(new_buffer);
  }
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are
    // true, so that we can count the number of async invocations exactly. When
    // it is valid, it is the "sum of extents of loops that have been
    // executed" - 1, e.g. for the epilogue it is prologue extent + body
    // extent - 1. This is only needed to compute the wait count for an
    // epilogue without async producers.
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;

    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };
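  // Illustrative example (assumed extents): with a prologue of extent 1 and a
  // body of extent N - 1, producer_head at the start of the epilogue is
  // 1 + (N - 1) - 1 = N - 1, the imaginary index written by the last async
  // operation issued so far.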
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
  struct AsyncStateLocal {
    struct PendingWait {
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
      int insert_before;
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
@@ -468,8 +498,8 @@ class PipelineRewriter : public StmtExprMutator {
    std::vector<PendingWait> pending_waits;

    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
    Optional<PrimExpr> producer_head;
  };
@@ -483,31 +513,35 @@ class PipelineRewriter : public StmtExprMutator {
    bool is_async;
  };
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }

      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, subtract 1
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
        }
@@ -516,41 +550,43 @@ class PipelineRewriter : public StmtExprMutator {
      // We can relax the in-flight count by the number of independent commit
      // groups.
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
    }
  }
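  // Illustrative sketch (assumed elided arithmetic): per commit group the
  // elided code accumulates producer_head - consumer_head into in_flight_cnt,
  // so a consumer one iteration behind an async producer waits until at most
  // one commit group is in flight; trailing groups it does not read from
  // relax the count by 1 each.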
  // Given pipelined blocks and async-related information, generate the final
  // loop statements with async scopes (if any).
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
      }
    }

    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
@@ -561,9 +597,9 @@ class PipelineRewriter : public StmtExprMutator {
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
@@ -579,15 +615,18 @@ class PipelineRewriter : public StmtExprMutator {
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
                bool need_bound_check) {
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;

    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
      new_loop_var = start; // use constants as the loop var for unit loops
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
@@ -598,45 +637,52 @@ class PipelineRewriter : public StmtExprMutator {
    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
      PrimExpr normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

      // Adjust the block predicate and the body according to the final loop
      // bound [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }

      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

      if (pipeline_info_[block].async) {
        auto &local_state = async_states_local[stage];
        local_state.producer_head = normalized_access_index;

        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)),
                           tir::attr::async_scope, 1, n->body);
      }

      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
    }

    PopulateWaitCounts(new_blocks, &async_states_local);
@@ -655,8 +701,8 @@ class PipelineRewriter : public StmtExprMutator {
    if (!is_unit_loop) {
      Map<String, ObjectRef> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
@@ -664,16 +710,17 @@ class PipelineRewriter : public StmtExprMutator {
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), NullOpt, preserved_annotations);
    }

    // Update producer heads in the global async states.
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
    }

    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
  }
  arith::Analyzer analyzer_;
@@ -690,22 +737,23 @@

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the
 * destination to the source.
 */
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
        for (const Block &writer : it->second) {
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
@@ -715,83 +763,89 @@ void BuildDependencyGraph(
        }
      }
    }
    for (const BufferRegion &write : block->writes) {
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}
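// Illustrative example (assumed blocks): if block A writes buffer B and a
// later block C reads B, the read-after-write edge A -> C is recorded in
// dep_src2dst (and C -> A in dep_dst2src when that map is supplied).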
class PipelineInjector : private StmtExprMutator {
public:
  static Stmt Inject(const PrimFunc &func) {
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }
private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}

  /*!
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for a dependency (e.g. read-after-write) from statement A to
   * statement B, it requires:
   *   case 1: stage(A) < stage(B), or
   *   case 2: stage(A) == stage(B) and order(A) < order(B)
   */
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
      used_orders.insert(order);
    }

    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
            << " cannot depend on statement " << src << " in a later stage "
            << src_info.stage;
        if (src_info.stage == dst_info.stage) {
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with a buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
        }
      }
    }
  }
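  // Illustrative example (assumed schedule): a writer block W and a reader
  // block R of the same buffer placed in the same stage must keep
  // order(W) < order(R); annotating order(R) < order(W) trips the CHECK_LT
  // above.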
  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be a direct child of the for-loop. If the for-loop has BlockRealize
    // as its child, the pipeline body will be the child of the block.
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
@@ -801,31 +855,32 @@ class PipelineInjector : private StmtExprMutator {
      pipeline_body = for_node->body;
    }
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();

    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
    PipelineInfo pipeline_info;
    Array<Block> original_order; // pipeline body blocks in the original order
    auto f_add_child = [&](const Stmt &child) {
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
@@ -834,21 +889,26 @@ class PipelineInjector : private StmtExprMutator {
      }
    }
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";

    std::unordered_set<int> pipeline_async_stages;
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
@@ -856,43 +916,44 @@ class PipelineInjector : private StmtExprMutator {
    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
      pipeline_info.emplace(original_order[i], stage_order);
    }
    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     GetRef<For>(op), pipeline_info)
                        .BuildPipeline();

    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }
  bool HasPipelineAnnotation(const ForNode *op) const {
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
@@ -901,10 +962,12 @@ class PipelineInjector : private StmtExprMutator {
      return true;
    }
    if (has_stage) {
      LOG(FATAL)
          << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL)
          << "ValueError: Stage of the software pipeline is not defined.";
    }
    return false;
  }
@@ -914,13 +977,13 @@ class PipelineInjector : private StmtExprMutator {
};
/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
@@ -931,5 +994,5 @@ tir::transform::Pass InjectSoftwarePipeline() {
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);
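// Usage sketch (illustrative): once registered, the pass can be fetched and
// applied like any other transform, e.g. from Python via
//   tvm.get_global_func("tl.transform.InjectSoftwarePipeline")()(mod)
// where `mod` is an IRModule whose loops carry the pipeline annotations.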

} // namespace tl
} // namespace tvm
@@ -30,11 +30,11 @@

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"

namespace tvm {
namespace tl {
@@ -49,7 +49,7 @@ struct LayoutInferenceResult {
};

class BufferUseDefCollector : public StmtExprVisitor {
public:
  BufferUseDefCollector() = default;

  LayoutInferenceResult Run() {
@@ -59,22 +59,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
    // maintain a bfs queue and infer common layout
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++)
      q.push(i);

    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto updates = next->InferLayout(
          LayoutInferArgs{
              target_,
              static_cast<size_t>(*as_const_int(iter_var->dom->extent)),
              layout_map},
          level);
      for (const auto &[buffer, layout] : updates) {
        if (layout_map.count(buffer)) {
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer;
        } else {
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;
          for (int idx : use_list_[buffer]) {
            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
...@@ -108,16 +113,17 @@ class BufferUseDefCollector : public StmtExprVisitor { ...@@ -108,16 +113,17 @@ class BufferUseDefCollector : public StmtExprVisitor {
} }
// Check that all fragments have been inferred // Check that all fragments have been inferred
for (const auto& [buffer, _] : use_list_) { for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0) if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
LOG_ERROR << "The layout for fragment " << buffer << " can not be inferred correctly."; LOG_ERROR << "The layout for fragment " << buffer
<< " can not be inferred correctly.";
} }
// Collect the layout for for nodes // Collect the layout for for nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    for (auto &base_infer : infer_list_) {
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for can not be inferred correctly : \n" << "The Layout for Parallel for can not be inferred correctly : \n"
            << for_infer->GetRoot();
@@ -130,25 +136,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "LayoutInference: requires the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    StmtExprVisitor::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;
    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
@@ -158,7 +166,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
@@ -167,7 +175,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
@@ -175,10 +183,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
@@ -188,13 +196,14 @@ class BufferUseDefCollector : public StmtExprVisitor {
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
@@ -203,7 +212,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
@@ -216,7 +225,8 @@ class BufferUseDefCollector : public StmtExprVisitor {
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  IterVar thread_var_;
  std::vector<IterVar> thread_var_vec_;
  Target target_;
@@ -224,10 +234,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
};
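The Run() method above is a classic worklist fixed point: each operator infers layouts, and whenever a buffer's layout changes, every operator recorded for that buffer in use_list_ is re-enqueued until the map stabilizes. A self-contained toy sketch of the same pattern (hypothetical names; max-propagation stands in for InferLayout, this is not the pass's actual code):

#include <queue>
#include <vector>

// Worklist propagation over a use graph: each step merges a node's value
// into its neighbors' (here: takes the max) and re-enqueues neighbors that
// changed. Terminates because values only grow, bounded by the global max.
void PropagateMax(const std::vector<std::vector<int>> &neighbors,
                  std::vector<int> &value) {
  int n = static_cast<int>(value.size());
  std::queue<int> q;
  std::vector<bool> in_queue(n, true);
  for (int i = 0; i < n; i++)
    q.push(i);
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    in_queue[cur] = false;
    for (int nb : neighbors[cur]) {
      if (value[nb] < value[cur]) {  // stand-in for a layout update
        value[nb] = value[cur];
        if (!in_queue[nb]) {
          in_queue[nb] = true;
          q.push(nb);
        }
      }
    }
  }
}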
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector;
    collector.Collect(f);
@@ -237,11 +247,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
    for (auto buffer : block->alloc_buffers) {
@@ -255,11 +266,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto loop_layout = result_.for_map[GetRef<For>(op)];
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      for_node = VectorizeLoop(for_node);
      if (result_.predicate_map.count(GetRef<For>(op))) {
        return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
@@ -270,7 +282,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
@@ -281,7 +293,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_;
};
@@ -297,5 +309,5 @@ tvm::transform::Pass LayoutInference() {
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);
} // namespace tl
} // namespace tvm
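A sketch of composing this pass with the pipelining pass from the previous file (illustrative only; assumes the standard tvm::transform::Sequential API, and the ordering shown is an assumption, not mandated by this commit):

#include <tvm/ir/transform.h>

// One plausible ordering: infer layouts first, then pipeline the loops.
tvm::IRModule LowerTileLangModule(tvm::IRModule mod) {
  tvm::transform::Sequential seq(
      {tvm::tl::LayoutInference(), tvm::tl::InjectSoftwarePipeline()});
  return seq(mod);
}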
@@ -30,8 +30,8 @@

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
@@ -43,11 +43,11 @@ using arith::IRMutatorWithAnalyzer;

// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
public:
  std::vector<For> leaf_for_nodes;

private:
  void VisitStmt_(const ForNode *op) final {
    has_child_for_ = false;
    bool parent_has_child_for = parent_has_child_for_;
    parent_has_child_for_ = false;
@@ -62,7 +62,7 @@ class LeafForFinder : public StmtVisitor {
    parent_has_child_for_ = true;
  }

private:
  bool has_child_for_ = false;
  bool parent_has_child_for_ = false;
};
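For intuition: only innermost loops are collected as leaves. A small sketch that builds a two-level nest with the standard tvm::tir constructors (extents and names are hypothetical):

#include <tvm/tir/stmt.h>
#include <tvm/tir/var.h>

// Builds `for i in [0, 4): for j in [0, 8): evaluate(0)`. Running
// LeafForFinder over the result collects only the inner j loop, since the
// i loop contains a child For.
tvm::tir::Stmt MakeTwoLevelNest() {
  using namespace tvm::tir;
  Var i("i"), j("j");
  Stmt leaf = Evaluate(0);
  Stmt inner = For(j, 0, 8, ForKind::kSerial, leaf);
  return For(i, 0, 4, ForKind::kSerial, inner);
}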
@@ -75,11 +75,11 @@ class LeafForFinder : public StmtVisitor {
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
  arith::Analyzer *analyzer;
  explicit GlobalMemChecker(arith::Analyzer *analyzer) : analyzer(analyzer) {}

  void VisitExpr_(const BufferLoadNode *op) final {
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
@@ -87,7 +87,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
@@ -96,21 +96,24 @@ struct GlobalMemChecker : public StmtExprVisitor {
  }

  // Helper function to determine if a buffer is global
  bool IsGlobalBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes. In typical TVM IR, global buffers have scope
    // "global". Here we assume a helper function GetPtrStorageScope is
    // available. If not, you might need to parse buffer->data->name_hint or
    // associated attributes.
    String scope = buffer.scope();
    return scope == "global";
  }

  // Check each index against the buffer shape dimensions
  void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
                          bool is_load) {
    // Ensure the indices count matches the buffer dimension
    if (indices.size() != buffer->shape.size()) {
      LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
                   << indices.size() << ") vs. shape size ("
                   << buffer->shape.size() << ")";
      return;
    }
@@ -130,18 +133,19 @@ struct GlobalMemChecker : public StmtExprVisitor {
  Array<PrimExpr> GetConditions() { return _conditions; }

private:
  Array<PrimExpr> _conditions;
};
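The checker's core question is whether the analyzer can prove index < extent for every global access; only unprovable accesses contribute a guard to _conditions. A minimal sketch of that style of query (assuming the standard tvm::arith::Analyzer API; the bounds are made-up):

#include <tvm/arith/analyzer.h>
#include <tvm/tir/var.h>

// With 0 <= i < 128 bound on the loop variable, i + 16 < 256 is provable,
// so an access guarded by this condition would need no runtime check.
bool CanProveInBounds() {
  tvm::arith::Analyzer analyzer;
  tvm::tir::Var i("i", tvm::DataType::Int(32));
  analyzer.Bind(i, tvm::Range::FromMinExtent(0, 128));
  return analyzer.CanProve(i + 16 < 256);
}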
class SafeMemorysRewriter : public StmtExprMutator {
  arith::Analyzer *analyzer_;

public:
  explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
      : analyzer_(analyzer) {}

private:
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    // Check if the buffer is in global scope
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    GlobalMemChecker checker(analyzer_);
@@ -173,12 +177,13 @@ class SafeMemorysRewriter : public StmtExprMutator {
  // Handle Call Nodes
  // For example
  // T.call_extern("handle", "atomicAddx2", T.address_of(C),
  // T.address_of(C_shared))
  Stmt VisitStmt_(const EvaluateNode *op) final {
    auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
    auto call = Downcast<Call>(evaluate->value);
    if (call.defined() && call->op == builtin::call_extern()) {
      GlobalMemChecker checker(analyzer_);
      checker(call);
      Array<PrimExpr> conditions = checker.GetConditions();
@@ -197,13 +202,12 @@ class SafeMemorysRewriter : public StmtExprMutator {
    return evaluate;
  }

  bool isSharedBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
    return scope == "shared" || scope == "shared.dyn";
  }

  bool IsGlobalBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
    return scope == "global";
  }
@@ -211,32 +215,34 @@ class SafeMemorysRewriter : public StmtExprMutator {
// Class to legalize memory accesses by transforming them into safe ones
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
  // Static method to substitute and transform the given PrimFunc
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    // Create an instance of the legalizer with the analyzer
    SafeMemoryLegalizer substituter(&analyzer);
    // Get a mutable copy of the function node
    PrimFuncNode *fptr = f.CopyOnWrite();
    // Apply the legalizer to the function body
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  // Constructor initializing the base class with the analyzer
  SafeMemoryLegalizer(arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer) {}

  // Override the VisitStmt_ method to handle ForNode (loop statements)
  Stmt VisitStmt_(const ForNode *op) final {
    // Visit and potentially modify the loop node
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto has_inner_loop = HasInnerLoop(for_node->body);
    if (!has_inner_loop) {
      SafeMemorysRewriter rewriter(analyzer_);
      for_node.CopyOnWrite()->body = rewriter(for_node->body);
      // // Detect Buffer Load Node in the loop body, collect the indices and
      // // buffer size
      // // Run the checker on the loop body
      // GlobalMemChecker checker(analyzer_);
@@ -257,7 +263,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  static bool HasInnerLoop(const Stmt &stmt) {
    LeafForFinder finder;
    finder(stmt);
    return finder.leaf_for_nodes.size() > 0;
@@ -279,5 +285,5 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
    .set_body_typed(LegalizeSafeMemoryAccess);
} // namespace tl
} // namespace tvm
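The net effect on a leaf loop, shown as a conceptual C++ analogue (illustrative only; the pass rewrites TIR BufferStore/call_extern nodes, not C++ source):

// Before the pass, the store may run past `n`; after it, stores whose
// indices the analyzer cannot prove in-bounds are wrapped in a predicate.
void GuardedCopy(float *out, const float *in, int base, int n) {
  for (int i = 0; i < 128; ++i) {
    if (base + i < n)  // the guard LegalizeSafeMemoryAccess would insert
      out[base + i] = in[base + i];
  }
}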