"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "6e643b00a22ae251ca120936fda34420f8d88fc5"
Unverified Commit 2f34840f authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333)

* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix
parent 71b73e18
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h" #include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -48,92 +49,9 @@ using namespace tir; ...@@ -48,92 +49,9 @@ using namespace tir;
* fails with an ICHECK (runtime assertion). No other validation is * fails with an ICHECK (runtime assertion). No other validation is
* performed here. * performed here.
*/ */
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>(); ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
...@@ -535,9 +453,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -535,9 +453,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
// Build access pointers from regions locally // Build access pointers from regions locally
PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1); PrimExpr Aptr =
PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1); MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3); PrimExpr Bptr =
MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Cptr =
MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);
std::stringstream ss; std::stringstream ss;
std::string op_name; std::string op_name;
......
...@@ -14,98 +14,16 @@ ...@@ -14,98 +14,16 @@
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h" #include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
/** /**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer * @brief Construct a Gemm operator from serialized TL arguments and a buffer
......
...@@ -17,105 +17,16 @@ ...@@ -17,105 +17,16 @@
#include "region.h" #include "region.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h" #include "tvm/tir/stmt.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes (only tl.region)
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims";
PrimExpr offset, extent;
if (ndim == 1) {
// Simple 1D region: offset and extent come from the single axis.
auto axis = region->region[0];
offset = axis->min;
extent = axis->extent;
} else {
// Compute row-major strides for ndim >= 2
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements) // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
}
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
......
/*!
* \file tl/op/utils.cc
* \brief Common utilities implementation for TL ops.
*/
#include "utils.h"
#include <tvm/tir/builtin.h>
namespace tvm {
namespace tl {
using namespace tir;
// Convert an operator argument into a BufferRegion.
//
// Supported argument forms, tried in this order:
//   1. A BufferRegion: returned unchanged.
//   2. A BufferLoad: every index becomes one Range. A stride-1 Ramp with
//      constant lanes maps to [base, base + lanes); any other index maps to a
//      single-element range starting at that index.
//   3. A tl.region(...) call: rebuilt through RegionOp using `vmap`.
//   4. A builtin tvm_access_ptr(...) call: the data var (args[1]) is looked up
//      in `vmap` and the whole buffer is taken as the region.
// Anything else is a fatal error.
BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
                                     const BufferMap &vmap) {
  // Form 1: nothing to convert.
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }

  // Form 2: translate each load index into a per-axis range.
  if (const auto *buffer_load = arg.as<BufferLoadNode>()) {
    Array<Range> per_axis;
    for (const PrimExpr &idx : buffer_load->indices) {
      const auto *ramp = idx.as<RampNode>();
      if (ramp == nullptr) {
        // Non-vector index: a single element along this axis.
        per_axis.push_back(Range::FromMinExtent(idx, 1));
        continue;
      }
      // Vector index: only contiguous, fixed-width ramps are accepted.
      ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
      ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
          << "Only stride-1 Ramp is supported in region conversion";
      ICHECK(ramp->lanes.as<IntImmNode>())
          << "Scalable vector lanes not supported in region conversion";
      per_axis.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
    }
    return BufferRegion(buffer_load->buffer, per_axis);
  }

  // Forms 3 and 4 are both Call nodes, distinguished by their op.
  const auto *call = arg.as<CallNode>();
  if (call != nullptr) {
    if (call->op.same_as(RegionOp::Get())) {
      // tl.region(...): let RegionOp decode the buffer and its ranges.
      RegionOp decoded(call->args, vmap);
      return BufferRegion(decoded->GetBuffer(), decoded->GetRanges());
    }
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      // tvm_access_ptr(...): args[1] is the data var; cover the full buffer.
      Var data_var = Downcast<Var>(call->args[1]);
      Buffer target = vmap.at(data_var);
      Array<Range> whole;
      for (PrimExpr dim : target->shape) {
        whole.push_back(Range(IntImm(dim->dtype, 0), dim));
      }
      return BufferRegion(target, whole);
    }
  }

  LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg;
  throw; // Unreachable
}
// Build a builtin::tvm_access_ptr(...) handle expression for `region`.
//
// Parameters:
//   region     - buffer region to address. The dimensionality is taken from
//                the shape of region->buffer.
//   rw_mask    - access mask forwarded as the last tvm_access_ptr argument
//                (1 = read, 2 = write, 3 = read-write).
//   require_2d - when true, fail unless the buffer has at least 2 dims.
//
// Resulting offset / extent (in elements):
//   - 1D buffer: offset = region[0].min, extent = region[0].extent.
//   - ndim >= 2: offset = sum over the leading (ndim - 2) axes of
//     min_i * row-major stride_i, extent = product of the last two extents.
//     The mins of the last two axes are not added to the offset.
//
// Fails an ICHECK when `require_2d` is set and the buffer has fewer than 2
// dims, or when the buffer has no dims at all.
PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
                                 bool require_2d) {
  Buffer buf = region->buffer;
  int ndim = static_cast<int>(buf->shape.size());
  if (require_2d) {
    ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims";
  }
  // A 0-dim buffer has no axis to derive an offset or extent from. Reject it
  // here with a clear message instead of failing below on an out-of-range
  // shape/region index.
  ICHECK(ndim >= 1) << "Expect buffers with at least 1 dim, but buffer '"
                    << buf->name << "' has none";

  PrimExpr offset, extent;
  if (ndim == 1) {
    // 1D: the single axis provides both offset and extent.
    auto axis = region->region[0];
    offset = axis->min;
    extent = axis->extent;
  } else {
    // Row-major strides: stride of the last axis is 1, each earlier axis
    // multiplies in the extent of the axis after it.
    std::vector<PrimExpr> strides(ndim);
    PrimExpr one = make_const(buf->shape[0].dtype(), 1);
    PrimExpr cur = one;
    for (int i = ndim - 1; i >= 0; --i) {
      strides[i] = cur;
      cur = cur * buf->shape[i];
    }
    // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
    offset = make_const(buf->shape[0].dtype(), 0);
    for (int i = 0; i < ndim - 2; ++i) {
      offset = offset + region->region[i]->min * strides[i];
    }
    // Extent: last two extents product (elements)
    extent =
        region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
  }

  // Argument order of tvm_access_ptr: type annotation, data var, offset,
  // extent, rw_mask. The call itself is handle-typed.
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/utils.h
* \brief Common utilities for TL ops.
*/
#ifndef TVM_TL_OP_UTILS_H_
#define TVM_TL_OP_UTILS_H_
#include "./operator.h"
#include "region.h"
#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tl {
using namespace tir;
// Normalize an argument to a BufferRegion so ops can uniformly consume
// regions.
//
// Accepted forms of `arg`:
//   - BufferRegion: returned unchanged.
//   - BufferLoad: each index becomes one Range; a stride-1 Ramp with constant
//     lanes covers [base, base + lanes), any other index covers one element.
//   - tl.region(...) call: reconstructed through RegionOp using `vmap`.
//   - builtin tvm_access_ptr(...) call: the data var (second call argument) is
//     looked up in `vmap` and the full buffer extent is used as the region.
// Any other argument is reported as a fatal error. A Ramp index whose stride
// is not the constant 1, or whose lanes are not a constant, fails an ICHECK.
//
// `vmap` maps buffer data vars to their Buffer objects.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap);
// Build a tvm_access_ptr(handle) from a BufferRegion.
// - `rw_mask` is forwarded as the access mask of the call (1 = read,
//   2 = write, 3 = read-write).
// - If `require_2d` is true, checks buffer ndim >= 2 (ICHECK failure
//   otherwise). It defaults to false.
// - For 1D buffers (when allowed), offset = min and extent = extent of the
//   single axis.
// - For ndim >= 2, offset sums min * row-major stride over all but the last
//   two dims, and extent is the product of the last two extents. The mins of
//   the last two dims are not included in the offset.
// Offset and extent are expressed in elements of the buffer dtype.
TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask, bool require_2d = false);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_UTILS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment