Unverified Commit f5d9da46 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Phaseout vmap for Tile Operators (#1334)



* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent fac04006
import tilelang.testing
import example_mla_decode
......
......@@ -334,10 +334,10 @@ def get_autotuned_kernel(
return main
def check_correctness_and_bench(kernel, N, K, bench_ref=True):
def check_correctness_and_bench(kernel, N, K, do_bench=True):
profiler = kernel.get_profiler()
profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2)
if bench_ref:
if do_bench:
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50)
print(f"Torch Latency: {latency} ms")
latency = profiler.do_bench(kernel, warmup=50)
......@@ -350,12 +350,13 @@ def main(do_bench: bool = True):
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
args, _ = parser.parse_known_args()
N, K = args.n, args.k
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K)
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(
gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
print("Test passed!")
......
import tilelang.testing
import example_gemv
......@@ -8,4 +6,4 @@ def test_example_gemv():
if __name__ == "__main__":
tilelang.testing.main()
test_example_gemv()
......@@ -5,7 +5,7 @@
*/
#include "./atomic_add.h"
#include "./region.h"
#include "utils.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......@@ -26,32 +26,27 @@ using namespace tir;
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
*
* Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third
* argument is provided, it is interpreted as an integer immediate and stored as
* the node's coalesced width.
* regions and their backing Buffers from the first two region-style expressions
* in `args` (BufferLoad/BufferRegion), and stores them along with their
* ranges. If a third argument is provided, it is interpreted as an integer
* immediate and stored as the node's coalesced width.
*
* @param args Call-style PrimExprs where:
* - args[0] is the source region call,
* - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes:
* - The constructor checks that args[0] and args[1] are CallNodes.
* - The constructor checks that args[0] and args[1] are region-compatible.
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
......
......@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL AtomicAdd(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -16,7 +16,7 @@
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "region.h"
#include "utils.h"
#include "../target/cuda.h"
#include "../target/utils.h"
......@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
/*!
* \brief Construct a Copy operator node from call arguments and a buffer map.
*
* This constructor parses the first two entries of `args` as Call nodes
* describing source and destination Regions (via RegionOp), extracts their
* Buffers and Ranges, and stores them on the newly created CopyNode. It also
* This constructor parses the first two entries of `args` as regions
* (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores
* them on the newly created CopyNode. It also
* reads optional arguments:
* - args[2] (IntImm): coalesced width (stored only if > 0),
* - args[3] (Bool): disable TMA lowering flag,
* - args[4] (IntImm): eviction policy.
*
* Preconditions:
* - `args` must contain at least two Call-compatible PrimExpr entries
* describing regions; an ICHECK will fail if they are not CallNodes.
* - `args` must contain at least two region-compatible PrimExpr entries
* (BufferLoad/BufferRegion); ICHECK will fail otherwise.
*
* @param args Array of PrimExpr where:
* - args[0] is the source Region call,
* - args[1] is the destination Region call,
* - optional args[2..4] are coalesced width, disable_tma, and eviction
* policy.
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
Copy::Copy(Array<PrimExpr> args) {
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
......@@ -250,6 +246,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
......@@ -302,7 +299,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
......@@ -1729,20 +1725,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* GPU intrinsics.
*
* @param args Array of PrimExpr TL-call arguments (see list above).
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args) {
ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2];
node->c_step = args[3];
node->kernel = args[4].as<IntImm>().value()->value;
node->stride = args[5].as<IntImm>().value()->value;
node->dilation = args[6].as<IntImm>().value()->value;
node->padding = args[7].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value;
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src_ = node->srcRegion_->buffer;
node->dst_ = node->dstRegion_->buffer;
node->nhw_step_ = args[2];
node->c_step_ = args[3];
node->kernel_ = args[4].as<IntImm>().value()->value;
node->stride_ = args[5].as<IntImm>().value()->value;
node->dilation_ = args[6].as<IntImm>().value()->value;
node->padding_ = args[7].as<IntImm>().value()->value;
node->eviction_policy_ = args[8].as<IntImm>().value()->value;
data_ = std::move(node);
}
......@@ -1793,24 +1790,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const {
Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src->shape.size() == 4);
ICHECK(dst->shape.size() == 2);
ICHECK(src->dtype == dst->dtype);
ICHECK(src_.scope() == "global" &&
(dst_.scope() == "shared.dyn" || dst_.scope() == "shared"));
ICHECK(src_->shape.size() == 4);
ICHECK(dst_->shape.size() == 2);
ICHECK(src_->dtype == dst_->dtype);
Layout shared_layout;
if (T.layout_map.count(dst)) {
shared_layout = T.layout_map[dst];
if (T.layout_map.count(dst_)) {
shared_layout = T.layout_map[dst_];
}
TMAIm2ColDesc desc;
desc.rank = src->shape.size();
desc.data_type = to_CUtensorMapDataType(src->dtype);
desc.global_addr = src->data;
desc.global_shape = ReverseArray(src->shape);
desc.rank = src_->shape.size();
desc.data_type = to_CUtensorMapDataType(src_->dtype);
desc.global_addr = src_->data;
desc.global_shape = ReverseArray(src_->shape);
if (!src->strides.empty()) {
desc.global_stride = ReverseArray(src->strides);
if (!src_->strides.empty()) {
desc.global_stride = ReverseArray(src_->strides);
} else {
// Create stride from shape
PrimExpr stride = 1;
......@@ -1824,13 +1821,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * src->dtype.bytes();
return cast(DataType::Int(64), e) * src_->dtype.bytes();
});
desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding};
desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value;
desc.elem_stride = {1, stride_, stride_, 1};
desc.lower_corner = {-padding_, -padding_};
desc.upper_corner = {-padding_, -padding_};
desc.smem_box_pixel = Downcast<IntImm>(dst_->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst_->shape[1])->value;
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
......@@ -1844,15 +1841,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
......@@ -1871,43 +1868,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
<< "Currently can only support divisible channel case";
global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back(
dilation *
FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]),
kernel));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel,
desc.global_shape[0] * kernel));
dilation_ *
FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]),
kernel_));
image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel,
desc.global_shape[0] * kernel_));
PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride_) +
1;
PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride_) +
1;
global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_);
global_coords.push_back(
stride *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) -
padding);
stride_ *
FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) -
padding_);
global_coords.push_back(
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 2);
args.push_back(create_desc);
args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_;
auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
for (auto offset : image_offset)
args.push_back(offset);
args.push_back(this->eviction_policy);
args.push_back(this->eviction_policy_);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
......
......@@ -280,7 +280,7 @@ public:
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Copy(Array<PrimExpr> args);
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
......@@ -296,14 +296,16 @@ public:
*/
class Conv2DIm2ColOpNode : public TileOperatorNode {
public:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
BufferRegion srcRegion_, dstRegion_;
Buffer src_,
dst_; // Source (input feature map) and destination (im2col matrix)
int stride_; // Stride for convolution
int padding_; // Padding amount
int dilation_; // Dilation factor
int kernel_; // Kernel size
int eviction_policy_; // Cache eviction policy
PrimExpr nhw_step_; // Step size in NHW dimensions
PrimExpr c_step_; // Step size in channel dimension
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode);
......@@ -311,13 +313,15 @@ public:
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<Conv2DIm2ColOpNode>()
.def_ro("src", &Conv2DIm2ColOpNode::src)
.def_ro("dst", &Conv2DIm2ColOpNode::dst)
.def_ro("stride", &Conv2DIm2ColOpNode::stride)
.def_ro("padding", &Conv2DIm2ColOpNode::padding)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
.def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_)
.def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_)
.def_ro("src", &Conv2DIm2ColOpNode::src_)
.def_ro("dst", &Conv2DIm2ColOpNode::dst_)
.def_ro("stride", &Conv2DIm2ColOpNode::stride_)
.def_ro("padding", &Conv2DIm2ColOpNode::padding_)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_);
}
/*!
......@@ -342,7 +346,7 @@ class Conv2DIm2ColOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -17,7 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -52,63 +52,19 @@ using namespace tir;
* value].
* - args[0]: destination access (BufferLoad or pointer expression).
* - args[1]: value to fill (scalar or vector).
* @param vmap Mapping from buffer variables to Buffer objects; used to resolve
* the destination when args[0] is not a BufferLoad.
*
* Notes:
* - The constructor enforces constraints (e.g., stride == 1 ramps, constant
* lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
* of bounds.
*/
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
Fill::Fill(Array<PrimExpr> args) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
BufferRegion region = NormalizeToBufferRegion(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
CHECK(ramp->stride.as<IntImmNode>()->value == 1)
<< "Only stride 1 ramps are supported";
const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion";
node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
node->region.push_back(Range::FromMinExtent(index, 1));
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
}
if (args[1]->dtype != node->dst->dtype) {
node->value = Cast(node->dst->dtype, args[1]);
} else {
......
......@@ -45,7 +45,7 @@ private:
class Fill : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Fill(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,6 +12,7 @@
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -29,12 +30,14 @@ using namespace tir;
* @param args TL operator arguments: expects at least two elements where
* `args[0]` is an access pointer identifying the reducer variable
* and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
* @param vmap Mapping from variables to Buffers used to look up the reducer
* Buffer.
*/
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args) {
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
// Normalize any supported region expression
// (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the
// underlying Buffer as reducer.
auto region = NormalizeToBufferRegion(args[0]);
node->reducer = region->buffer;
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
}
......
......@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,7 +12,6 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
#include "utils.h"
......@@ -42,8 +41,6 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
......@@ -53,12 +50,12 @@ using namespace tir;
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
Gemm::Gemm(Array<PrimExpr> args) {
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
......@@ -83,12 +80,15 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
} else {
node->mbar_ = std::nullopt;
}
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node);
......@@ -500,11 +500,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
Array<PrimExpr> new_args;
auto mbarPtr =
MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, cCoords_));
new_args.push_back(mbarPtr_);
new_args.push_back(mbarPtr);
new_args.push_back(clearAccum_);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
......
......@@ -97,7 +97,7 @@ public:
// only will be enabled under cdna mfma instructions
int kPack_ = 1;
int wgWait_ = 0;
PrimExpr mbarPtr_;
BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_;
......@@ -144,7 +144,7 @@ private:
class Gemm : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Gemm(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,7 +12,6 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
#include "utils.h"
......@@ -46,19 +45,17 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
GemmPy::GemmPy(Array<PrimExpr> args) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
......@@ -83,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->mbar_ = std::nullopt;
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
}
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
......
......@@ -29,8 +29,8 @@ public:
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
BufferRegion mbarRegion_;
tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
......@@ -59,7 +59,8 @@ public:
.def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
.def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
.def_ro("mbar", &GemmPyNode::mbar_)
.def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_)
......@@ -82,7 +83,7 @@ private:
class GemmPy : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL GemmPy(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -14,6 +14,7 @@
#include "../target/utils.h"
#include "builtin.h"
#include "gemm.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
* The populated GemmSPNode is stored in the instance's internal data_ pointer.
*
* @param args Positional TL call arguments in the above order.
* @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
*
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
GemmSP::GemmSP(Array<PrimExpr> args) {
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->a_ = vmap[GetVarFromAccessPtr(args[0])];
node->e_ = vmap[GetVarFromAccessPtr(args[1])];
node->b_ = vmap[GetVarFromAccessPtr(args[2])];
node->c_ = vmap[GetVarFromAccessPtr(args[3])];
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->a_ = node->aRegion_->buffer;
node->e_ = node->eRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value;
......
......@@ -53,6 +53,7 @@ public:
class GemmSPNode : public TileOperatorNode {
public:
BufferRegion aRegion_, bRegion_, cRegion_, eRegion_;
tir::Buffer a_, b_, c_, e_;
bool transA_, transB_;
int m_, n_, k_;
......@@ -75,6 +76,10 @@ public:
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy_)
.def_ro("aRegion", &GemmSPNode::aRegion_)
.def_ro("bRegion", &GemmSPNode::bRegion_)
.def_ro("cRegion", &GemmSPNode::cRegion_)
.def_ro("eRegion", &GemmSPNode::eRegion_)
.def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_)
......@@ -96,7 +101,7 @@ private:
class GemmSP : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL GemmSP(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -24,16 +24,14 @@ using namespace tir;
*
* @param call The TIR Call whose operator and arguments will be used to build
* the TileOperator.
* @param vmap Buffer mapping passed through to the builder to resolve buffer
* references.
* @return TileOperator The constructed TileOperator, or a default (empty)
* TileOperator if no builder exists.
*/
TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator ParseOperator(Call call) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
auto tile_op = op_map[op](call->args, vmap);
auto tile_op = op_map[op](call->args);
ICHECK(tile_op.defined());
return tile_op;
}
......@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
* Otherwise returns a default-constructed (empty) TileOperator.
*
* @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
* @param vmap Mapping of buffer variables used when building the operator.
* @return TileOperator Parsed operator on success, or a default (empty)
* TileOperator if `stmt` is not an Evaluate(Call).
*/
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
TileOperator ParseOperator(Stmt stmt) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
return ParseOperator(tvm::ffi::GetRef<Call>(call));
}
return TileOperator();
}
......
......@@ -72,11 +72,10 @@ public:
Var GetVarFromAccessPtr(const PrimExpr &expr);
TileOperator ParseOperator(Call call, BufferMap vmap);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
TileOperator ParseOperator(Call call);
TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc =
ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
......@@ -85,10 +84,8 @@ using OpBuilderFunc =
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> args, BufferMap vmap) { \
return Entry(args, vmap); \
})
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
} // namespace tl
} // namespace tvm
......
......@@ -14,7 +14,6 @@
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h"
#include "utils.h"
......@@ -28,11 +27,11 @@ using namespace tir;
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ReduceOp::ReduceOp(Array<PrimExpr> args) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
// Accept BufferRegion/BufferLoad for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value;
......@@ -494,7 +493,7 @@ static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
return BufferRegion(buf, ranges);
}
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
CumSumOp::CumSumOp(Array<PrimExpr> args) {
/// CumSum constructor arguments:
/// - src: input buffer
/// - dst: output buffer
......@@ -504,8 +503,8 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
// node->src = vmap[GetVarFromAccessPtr(args[0])];
// node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
node->dim = args[2].as<IntImm>().value()->value;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment