Unverified commit f5d9da46, authored by Lei Wang and committed by GitHub
Browse files

[Refactor] Phase out vmap for Tile Operators (#1334)



* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent fac04006
import tilelang.testing import tilelang.testing
import example_mla_decode import example_mla_decode
......
...@@ -334,14 +334,14 @@ def get_autotuned_kernel( ...@@ -334,14 +334,14 @@ def get_autotuned_kernel(
return main return main
def check_correctness_and_bench(kernel, N, K, bench_ref=True): def check_correctness_and_bench(kernel, N, K, do_bench=True):
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2)
if bench_ref: if do_bench:
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50)
print(f"Torch Latency: {latency} ms") print(f"Torch Latency: {latency} ms")
latency = profiler.do_bench(kernel, warmup=50) latency = profiler.do_bench(kernel, warmup=50)
print(f"TileLang Latency: {latency} ms\n") print(f"TileLang Latency: {latency} ms\n")
def main(do_bench: bool = True): def main(do_bench: bool = True):
...@@ -350,12 +350,13 @@ def main(do_bench: bool = True): ...@@ -350,12 +350,13 @@ def main(do_bench: bool = True):
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
N, K = args.n, args.k N, K = args.n, args.k
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K) check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K) check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K) check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K) check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K) check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K) check_correctness_and_bench(
gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
print("Test passed!") print("Test passed!")
......
import tilelang.testing
import example_gemv import example_gemv
...@@ -8,4 +6,4 @@ def test_example_gemv(): ...@@ -8,4 +6,4 @@ def test_example_gemv():
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() test_example_gemv()
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include "./atomic_add.h" #include "./atomic_add.h"
#include "./region.h" #include "utils.h"
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
...@@ -26,32 +26,27 @@ using namespace tir; ...@@ -26,32 +26,27 @@ using namespace tir;
* @brief Construct an AtomicAdd operator from call arguments and a buffer map. * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
* *
* Builds the internal AtomicAddNode, extracts the source and destination * Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions * regions and their backing Buffers from the first two region-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third * in `args` (BufferLoad/BufferRegion), and stores them along with their
* argument is provided, it is interpreted as an integer immediate and stored as * ranges. If a third argument is provided, it is interpreted as an integer
* the node's coalesced width. * immediate and stored as the node's coalesced width.
* *
* @param args Call-style PrimExprs where: * @param args Call-style PrimExprs where:
* - args[0] is the source region call, * - args[0] is the source region call,
* - args[1] is the destination region call, * - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width. * - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes: * Notes:
* - The constructor checks that args[0] and args[1] are CallNodes. * - The constructor checks that args[0] and args[1] are region-compatible.
* - The constructed node is stored in this->data_. * - The constructed node is stored in this->data_.
*/ */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>(); ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto expr = args[i]; auto region = NormalizeToBufferRegion(args[i]);
auto call = expr.as<CallNode>(); rgs[i] = region->region;
ICHECK(call); bf[i] = region->buffer;
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
} }
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
...@@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) ...@@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); } TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator { ...@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode); AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap); TVM_DLL AtomicAdd(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "region.h" #include "utils.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include "../target/utils.h" #include "../target/utils.h"
...@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) { ...@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
/*! /*!
* \brief Construct a Copy operator node from call arguments and a buffer map. * \brief Construct a Copy operator node from call arguments and a buffer map.
* *
* This constructor parses the first two entries of `args` as Call nodes * This constructor parses the first two entries of `args` as regions
* describing source and destination Regions (via RegionOp), extracts their * (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores
* Buffers and Ranges, and stores them on the newly created CopyNode. It also * them on the newly created CopyNode. It also
* reads optional arguments: * reads optional arguments:
* - args[2] (IntImm): coalesced width (stored only if > 0), * - args[2] (IntImm): coalesced width (stored only if > 0),
* - args[3] (Bool): disable TMA lowering flag, * - args[3] (Bool): disable TMA lowering flag,
* - args[4] (IntImm): eviction policy. * - args[4] (IntImm): eviction policy.
* *
* Preconditions: * Preconditions:
* - `args` must contain at least two Call-compatible PrimExpr entries * - `args` must contain at least two region-compatible PrimExpr entries
* describing regions; an ICHECK will fail if they are not CallNodes. * (BufferLoad/BufferRegion); ICHECK will fail otherwise.
* *
* @param args Array of PrimExpr where: * @param args Array of PrimExpr where:
* - args[0] is the source Region call, * - args[0] is the source Region call,
* - args[1] is the destination Region call, * - args[1] is the destination Region call,
* - optional args[2..4] are coalesced width, disable_tma, and eviction * - optional args[2..4] are coalesced width, disable_tma, and eviction
* policy. * policy.
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/ */
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { Copy::Copy(Array<PrimExpr> args) {
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>(); ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto expr = args[i]; auto region = NormalizeToBufferRegion(args[i]);
auto call = expr.as<CallNode>(); rgs[i] = region->region;
ICHECK(call); bf[i] = region->buffer;
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
} }
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
...@@ -250,6 +246,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, ...@@ -250,6 +246,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const { Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list; Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0; size_t idx = 0;
...@@ -302,7 +299,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -302,7 +299,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (const auto &iv : loop_vars) for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom); analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size()) ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size() << "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name << ", src_range.size() = " << src_range.size() << ", src = " << src->name
...@@ -1729,20 +1725,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const { ...@@ -1729,20 +1725,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* GPU intrinsics. * GPU intrinsics.
* *
* @param args Array of PrimExpr TL-call arguments (see list above). * @param args Array of PrimExpr TL-call arguments (see list above).
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/ */
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args) {
ObjectPtr<Conv2DIm2ColOpNode> node = ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>(); tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->nhw_step = args[2]; node->src_ = node->srcRegion_->buffer;
node->c_step = args[3]; node->dst_ = node->dstRegion_->buffer;
node->kernel = args[4].as<IntImm>().value()->value; node->nhw_step_ = args[2];
node->stride = args[5].as<IntImm>().value()->value; node->c_step_ = args[3];
node->dilation = args[6].as<IntImm>().value()->value; node->kernel_ = args[4].as<IntImm>().value()->value;
node->padding = args[7].as<IntImm>().value()->value; node->stride_ = args[5].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value; node->dilation_ = args[6].as<IntImm>().value()->value;
node->padding_ = args[7].as<IntImm>().value()->value;
node->eviction_policy_ = args[8].as<IntImm>().value()->value;
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -1793,24 +1790,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const { ...@@ -1793,24 +1790,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const {
Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target)); ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && ICHECK(src_.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")); (dst_.scope() == "shared.dyn" || dst_.scope() == "shared"));
ICHECK(src->shape.size() == 4); ICHECK(src_->shape.size() == 4);
ICHECK(dst->shape.size() == 2); ICHECK(dst_->shape.size() == 2);
ICHECK(src->dtype == dst->dtype); ICHECK(src_->dtype == dst_->dtype);
Layout shared_layout; Layout shared_layout;
if (T.layout_map.count(dst)) { if (T.layout_map.count(dst_)) {
shared_layout = T.layout_map[dst]; shared_layout = T.layout_map[dst_];
} }
TMAIm2ColDesc desc; TMAIm2ColDesc desc;
desc.rank = src->shape.size(); desc.rank = src_->shape.size();
desc.data_type = to_CUtensorMapDataType(src->dtype); desc.data_type = to_CUtensorMapDataType(src_->dtype);
desc.global_addr = src->data; desc.global_addr = src_->data;
desc.global_shape = ReverseArray(src->shape); desc.global_shape = ReverseArray(src_->shape);
if (!src->strides.empty()) { if (!src_->strides.empty()) {
desc.global_stride = ReverseArray(src->strides); desc.global_stride = ReverseArray(src_->strides);
} else { } else {
// Create stride from shape // Create stride from shape
PrimExpr stride = 1; PrimExpr stride = 1;
...@@ -1824,13 +1821,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1824,13 +1821,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes // Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * src->dtype.bytes(); return cast(DataType::Int(64), e) * src_->dtype.bytes();
}); });
desc.elem_stride = {1, stride, stride, 1}; desc.elem_stride = {1, stride_, stride_, 1};
desc.lower_corner = {-padding, -padding}; desc.lower_corner = {-padding_, -padding_};
desc.upper_corner = {-padding, -padding}; desc.upper_corner = {-padding_, -padding_};
desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value; desc.smem_box_pixel = Downcast<IntImm>(dst_->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value; desc.smem_box_channel = Downcast<IntImm>(dst_->shape[1])->value;
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE); desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
...@@ -1844,15 +1841,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1844,15 +1841,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
if (StructuralEqual()(shared_layout, if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous, makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous, *stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous, *stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else { } else {
ICHECK(0) << "Cannot detect TMA layout."; ICHECK(0) << "Cannot detect TMA layout.";
...@@ -1871,43 +1868,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1871,43 +1868,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
<< "Currently can only support divisible channel case"; << "Currently can only support divisible channel case";
global_coords.push_back( global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0])); FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back( image_offset.push_back(
dilation * dilation_ *
FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]),
kernel)); kernel_));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel, image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel,
desc.global_shape[0] * kernel)); desc.global_shape[0] * kernel_));
PrimExpr h_dim = PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride) + stride_) +
1; 1;
PrimExpr w_dim = PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride) + stride_) +
1; 1;
global_coords.push_back( global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding); stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_);
global_coords.push_back( global_coords.push_back(
stride * stride_ *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) -
padding); padding_);
global_coords.push_back( global_coords.push_back(
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim)); FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args; Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 2); args.reserve(desc.rank * 2 + 2);
args.push_back(create_desc); args.push_back(create_desc);
args.push_back(0); // mbar placeholder args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst; auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_;
auto shared_addr = dst_buffer.access_ptr(2); auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr); args.push_back(shared_addr);
for (auto coord : global_coords) for (auto coord : global_coords)
args.push_back(coord); args.push_back(coord);
for (auto offset : image_offset) for (auto offset : image_offset)
args.push_back(offset); args.push_back(offset);
args.push_back(this->eviction_policy); args.push_back(this->eviction_policy_);
Stmt tma_copy = Stmt tma_copy =
IfThenElse(EQ(T.thread_var, T.thread_bounds->min), IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
......
...@@ -280,7 +280,7 @@ public: ...@@ -280,7 +280,7 @@ public:
* \param args Expression arguments for the copy. * \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping. * \param vmap Buffer variable mapping.
*/ */
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Copy(Array<PrimExpr> args);
/*! /*!
* \brief Get the TVM Op handle corresponding to this Copy op. * \brief Get the TVM Op handle corresponding to this Copy op.
...@@ -296,14 +296,16 @@ public: ...@@ -296,14 +296,16 @@ public:
*/ */
class Conv2DIm2ColOpNode : public TileOperatorNode { class Conv2DIm2ColOpNode : public TileOperatorNode {
public: public:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix) BufferRegion srcRegion_, dstRegion_;
int stride; // Stride for convolution Buffer src_,
int padding; // Padding amount dst_; // Source (input feature map) and destination (im2col matrix)
int dilation; // Dilation factor int stride_; // Stride for convolution
int kernel; // Kernel size int padding_; // Padding amount
int eviction_policy; // Cache eviction policy int dilation_; // Dilation factor
PrimExpr nhw_step; // Step size in NHW dimensions int kernel_; // Kernel size
PrimExpr c_step; // Step size in channel dimension int eviction_policy_; // Cache eviction policy
PrimExpr nhw_step_; // Step size in NHW dimensions
PrimExpr c_step_; // Step size in channel dimension
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode); TileOperatorNode);
...@@ -311,13 +313,15 @@ public: ...@@ -311,13 +313,15 @@ public:
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<Conv2DIm2ColOpNode>() refl::ObjectDef<Conv2DIm2ColOpNode>()
.def_ro("src", &Conv2DIm2ColOpNode::src) .def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_)
.def_ro("dst", &Conv2DIm2ColOpNode::dst) .def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_)
.def_ro("stride", &Conv2DIm2ColOpNode::stride) .def_ro("src", &Conv2DIm2ColOpNode::src_)
.def_ro("padding", &Conv2DIm2ColOpNode::padding) .def_ro("dst", &Conv2DIm2ColOpNode::dst_)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation) .def_ro("stride", &Conv2DIm2ColOpNode::stride_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel) .def_ro("padding", &Conv2DIm2ColOpNode::padding_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy); .def_ro("dilation", &Conv2DIm2ColOpNode::dilation_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_);
} }
/*! /*!
...@@ -342,7 +346,7 @@ class Conv2DIm2ColOp : public TileOperator { ...@@ -342,7 +346,7 @@ class Conv2DIm2ColOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode); Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "builtin.h" #include "builtin.h"
#include "region.h" #include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -52,62 +52,18 @@ using namespace tir; ...@@ -52,62 +52,18 @@ using namespace tir;
* value]. * value].
* - args[0]: destination access (BufferLoad or pointer expression). * - args[0]: destination access (BufferLoad or pointer expression).
* - args[1]: value to fill (scalar or vector). * - args[1]: value to fill (scalar or vector).
* @param vmap Mapping from buffer variables to Buffer objects; used to resolve
* the destination when args[0] is not a BufferLoad.
* *
* Notes: * Notes:
* - The constructor enforces constraints (e.g., stride == 1 ramps, constant * - The constructor enforces constraints (e.g., stride == 1 ramps, constant
* lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
* of bounds. * of bounds.
*/ */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { Fill::Fill(Array<PrimExpr> args) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>(); ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
// Case 1: Region descriptor call (tl.region) BufferRegion region = NormalizeToBufferRegion(args[0]);
if (const auto *call = args[0].as<CallNode>()) { node->dst = region->buffer;
if (call->op.same_as(RegionOp::Get())) { node->region = region->region;
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
CHECK(ramp->stride.as<IntImmNode>()->value == 1)
<< "Only stride 1 ramps are supported";
const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion";
node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
node->region.push_back(Range::FromMinExtent(index, 1));
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
}
if (args[1]->dtype != node->dst->dtype) { if (args[1]->dtype != node->dst->dtype) {
node->value = Cast(node->dst->dtype, args[1]); node->value = Cast(node->dst->dtype, args[1]);
......
...@@ -45,7 +45,7 @@ private: ...@@ -45,7 +45,7 @@ private:
class Fill : public TileOperator { class Fill : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Fill(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -29,12 +30,14 @@ using namespace tir; ...@@ -29,12 +30,14 @@ using namespace tir;
* @param args TL operator arguments: expects at least two elements where * @param args TL operator arguments: expects at least two elements where
* `args[0]` is an access pointer identifying the reducer variable * `args[0]` is an access pointer identifying the reducer variable
* and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min). * and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
* @param vmap Mapping from variables to Buffers used to look up the reducer
* Buffer.
*/ */
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args) {
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(); auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])]; // Normalize any supported region expression
// (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the
// underlying Buffer as reducer.
auto region = NormalizeToBufferRegion(args[0]);
node->reducer = region->buffer;
node->op = (ReducerOpType)*as_const_int(args[1]); node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node); data_ = std::move(node);
} }
......
...@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator { ...@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode); FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL FinalizeReducerOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h" #include "utils.h"
...@@ -42,8 +41,6 @@ using namespace tir; ...@@ -42,8 +41,6 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)] * (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
* *
* @note If `kPack` is provided it must be 1; otherwise the constructor * @note If `kPack` is provided it must be 1; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is * fails with an ICHECK (runtime assertion). No other validation is
...@@ -53,12 +50,12 @@ using namespace tir; ...@@ -53,12 +50,12 @@ using namespace tir;
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { Gemm::Gemm(Array<PrimExpr> args) {
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>(); ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer; node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer; node->b_ = node->bRegion_->buffer;
...@@ -83,11 +80,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { ...@@ -83,11 +80,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) { if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value; node->wgWait_ = args[15].as<IntImm>().value()->value;
} }
node->mbarPtr_ = args[16]; if (args.size() > 16) {
if (node->mbarPtr_.as<CallNode>()) { if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; node->mbarRegion_ =
} else { NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = std::nullopt; node->mbar_ = node->mbarRegion_->buffer;
} else {
node->mbar_ = std::nullopt;
}
} }
node->cCoords_ = Array<PrimExpr>( node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()}); {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
...@@ -500,11 +500,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -500,11 +500,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_; auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
Array<PrimExpr> new_args; Array<PrimExpr> new_args;
auto mbarPtr =
MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
new_args.push_back(StringImm(ss.str())); new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr); new_args.push_back(Aptr);
new_args.push_back(Bptr); new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, cCoords_)); new_args.push_back(BufferLoad(C_buffer, cCoords_));
new_args.push_back(mbarPtr_); new_args.push_back(mbarPtr);
new_args.push_back(clearAccum_); new_args.push_back(clearAccum_);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
......
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack_ = 1; int kPack_ = 1;
int wgWait_ = 0; int wgWait_ = 0;
PrimExpr mbarPtr_; BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_; Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_; mutable GemmWarpPolicy policy_;
...@@ -144,7 +144,7 @@ private: ...@@ -144,7 +144,7 @@ private:
class Gemm : public TileOperator { class Gemm : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Gemm(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h" #include "utils.h"
...@@ -46,19 +45,17 @@ using namespace tir; ...@@ -46,19 +45,17 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)] * (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
* *
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is * fails with an ICHECK (runtime assertion). No other validation is
* performed here. * performed here.
*/ */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { GemmPy::GemmPy(Array<PrimExpr> args) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>(); ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer; node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer; node->b_ = node->bRegion_->buffer;
...@@ -83,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { ...@@ -83,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) { if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value; node->wgWait_ = args[15].as<IntImm>().value()->value;
} }
node->mbarPtr_ = args[16]; if (args.size() > 16) {
if (node->mbarPtr_.as<CallNode>()) { if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; node->mbarRegion_ =
} else { NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = std::nullopt; node->mbar_ = node->mbarRegion_->buffer;
}
} }
node->cCoords_ = Array<PrimExpr>( node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()}); {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
......
...@@ -29,8 +29,8 @@ public: ...@@ -29,8 +29,8 @@ public:
int strideA_, strideB_; int strideA_, strideB_;
int offsetA_, offsetB_; int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false(); PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_; BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_; Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
...@@ -59,7 +59,8 @@ public: ...@@ -59,7 +59,8 @@ public:
.def_ro("offsetA", &GemmPyNode::offsetA_) .def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_) .def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_) .def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_) .def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
.def_ro("mbar", &GemmPyNode::mbar_)
.def_ro("cCoords", &GemmPyNode::cCoords_) .def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_) .def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_) .def_ro("wgWait", &GemmPyNode::wgWait_)
...@@ -82,7 +83,7 @@ private: ...@@ -82,7 +83,7 @@ private:
class GemmPy : public TileOperator { class GemmPy : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmPy(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "../target/utils.h" #include "../target/utils.h"
#include "builtin.h" #include "builtin.h"
#include "gemm.h" #include "gemm.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, ...@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
* The populated GemmSPNode is stored in the instance's internal data_ pointer. * The populated GemmSPNode is stored in the instance's internal data_ pointer.
* *
* @param args Positional TL call arguments in the above order. * @param args Positional TL call arguments in the above order.
* @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
* *
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2. * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/ */
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { GemmSP::GemmSP(Array<PrimExpr> args) {
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>(); ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->a_ = vmap[GetVarFromAccessPtr(args[0])]; node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->e_ = vmap[GetVarFromAccessPtr(args[1])]; node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->b_ = vmap[GetVarFromAccessPtr(args[2])]; node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->c_ = vmap[GetVarFromAccessPtr(args[3])]; node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->a_ = node->aRegion_->buffer;
node->e_ = node->eRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[4].as<Bool>().value(); node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value(); node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value; node->m_ = args[6].as<IntImm>().value()->value;
......
...@@ -53,6 +53,7 @@ public: ...@@ -53,6 +53,7 @@ public:
class GemmSPNode : public TileOperatorNode { class GemmSPNode : public TileOperatorNode {
public: public:
BufferRegion aRegion_, bRegion_, cRegion_, eRegion_;
tir::Buffer a_, b_, c_, e_; tir::Buffer a_, b_, c_, e_;
bool transA_, transB_; bool transA_, transB_;
int m_, n_, k_; int m_, n_, k_;
...@@ -75,6 +76,10 @@ public: ...@@ -75,6 +76,10 @@ public:
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>() refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy_) .def_ro("policy", &GemmSPNode::policy_)
.def_ro("aRegion", &GemmSPNode::aRegion_)
.def_ro("bRegion", &GemmSPNode::bRegion_)
.def_ro("cRegion", &GemmSPNode::cRegion_)
.def_ro("eRegion", &GemmSPNode::eRegion_)
.def_ro("a", &GemmSPNode::a_) .def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_) .def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_) .def_ro("c", &GemmSPNode::c_)
...@@ -96,7 +101,7 @@ private: ...@@ -96,7 +101,7 @@ private:
class GemmSP : public TileOperator { class GemmSP : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmSP(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -24,16 +24,14 @@ using namespace tir; ...@@ -24,16 +24,14 @@ using namespace tir;
* *
* @param call The TIR Call whose operator and arguments will be used to build * @param call The TIR Call whose operator and arguments will be used to build
* the TileOperator. * the TileOperator.
* @param vmap Buffer mapping passed through to the builder to resolve buffer
* references.
* @return TileOperator The constructed TileOperator, or a default (empty) * @return TileOperator The constructed TileOperator, or a default (empty)
* TileOperator if no builder exists. * TileOperator if no builder exists.
*/ */
TileOperator ParseOperator(Call call, BufferMap vmap) { TileOperator ParseOperator(Call call) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value(); Op op = call->op.as<Op>().value();
if (op_map.count(op)) { if (op_map.count(op)) {
auto tile_op = op_map[op](call->args, vmap); auto tile_op = op_map[op](call->args);
ICHECK(tile_op.defined()); ICHECK(tile_op.defined());
return tile_op; return tile_op;
} }
...@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) { ...@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
* Otherwise returns a default-constructed (empty) TileOperator. * Otherwise returns a default-constructed (empty) TileOperator.
* *
* @param stmt TIR statement to inspect; expected to be an Evaluate of a Call. * @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
* @param vmap Mapping of buffer variables used when building the operator.
* @return TileOperator Parsed operator on success, or a default (empty) * @return TileOperator Parsed operator on success, or a default (empty)
* TileOperator if `stmt` is not an Evaluate(Call). * TileOperator if `stmt` is not an Evaluate(Call).
*/ */
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) { TileOperator ParseOperator(Stmt stmt) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) { if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>(); auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap); return ParseOperator(tvm::ffi::GetRef<Call>(call));
} }
return TileOperator(); return TileOperator();
} }
......
...@@ -72,11 +72,10 @@ public: ...@@ -72,11 +72,10 @@ public:
Var GetVarFromAccessPtr(const PrimExpr &expr); Var GetVarFromAccessPtr(const PrimExpr &expr);
TileOperator ParseOperator(Call call, BufferMap vmap); TileOperator ParseOperator(Call call);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap); TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc = using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;
ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \ #define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \ const Op &Entry::Get() { \
...@@ -85,10 +84,8 @@ using OpBuilderFunc = ...@@ -85,10 +84,8 @@ using OpBuilderFunc =
} \ } \
TVM_REGISTER_OP("tl." #OpName) \ TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \ .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \ .set_attr<OpBuilderFunc>( \
[](Array<PrimExpr> args, BufferMap vmap) { \ "TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
return Entry(args, vmap); \
})
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "../op/parallel.h" #include "../op/parallel.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h" #include "tvm/tir/stmt.h"
#include "utils.h" #include "utils.h"
...@@ -28,11 +27,11 @@ using namespace tir; ...@@ -28,11 +27,11 @@ using namespace tir;
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst // Accept BufferRegion/BufferLoad for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer; node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer; node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value; std::string reduce_type = args[2].as<StringImm>().value()->value;
...@@ -494,7 +493,7 @@ static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { ...@@ -494,7 +493,7 @@ static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
return BufferRegion(buf, ranges); return BufferRegion(buf, ranges);
} }
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { CumSumOp::CumSumOp(Array<PrimExpr> args) {
/// CumSum constructor arguments: /// CumSum constructor arguments:
/// - src: input buffer /// - src: input buffer
/// - dst: output buffer /// - dst: output buffer
...@@ -504,8 +503,8 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -504,8 +503,8 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>(); ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
// node->src = vmap[GetVarFromAccessPtr(args[0])]; // node->src = vmap[GetVarFromAccessPtr(args[0])];
// node->dst = vmap[GetVarFromAccessPtr(args[1])]; // node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer; node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer; node->dst = node->dstRegion_->buffer;
node->dim = args[2].as<IntImm>().value()->value; node->dim = args[2].as<IntImm>().value()->value;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment