"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "3a9012e4354656eec842f736d8995350941a9d24"
Unverified Commit 4370309b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Support Layout/Fragment Reshape (#1241)



* Update layout handling and introduce reshape functionality

- Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes.
- Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents.
- Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations.
- Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios.

* lint fix

* Revert tvm submodule pointer to 1815c3e0b6ec4ead36370bbd1562025d8529017c; keep src unchanged

* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1

* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1

* remove useless prove

* remove comment

---------
Co-authored-by: tilelang-bot <bot@tilelang>
parent 02cfc2a3
Subproject commit 1815c3e0b6ec4ead36370bbd1562025d8529017c Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff
...@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const { ...@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
for (size_t i = 0; i < ret.size(); i++) { for (size_t i = 0; i < ret.size(); i++) {
auto ist = analyzer.int_set(forward_index_[i] + 1); auto ist = analyzer.int_set(forward_index_[i] + 1);
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) { if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
// X-OR Expression // Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
ret.Set(i, input_size_[i]); // Fall back to ConstIntBound to derive a safe extent.
auto cib = analyzer.const_int_bound(forward_index_[i]);
if (cib->min_value != arith::ConstIntBound::kNegInf &&
cib->max_value != arith::ConstIntBound::kPosInf &&
cib->min_value >= 0) {
// extent = max - min + 1, using 64-bit integer literal
ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
} else {
// Last-resort conservative fallback to avoid OOB/crash
// Prefer to keep dimension from known input_size_ if available.
if (i < input_size_.size()) {
ret.Set(i, input_size_[i]);
} else {
ret.Set(i, Integer(1));
}
}
} else { } else {
// CHECK(is_one(ist.min())) << ist.min();
ret.Set(i, ist.max()); ret.Set(i, ist.max());
} }
} }
...@@ -282,10 +296,156 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const { ...@@ -282,10 +296,156 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
return {Layout(outputs_shape, backward_index), level}; return {Layout(outputs_shape, backward_index), level};
} }
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
                           arith::Analyzer *analyzer) const {
  // Reshape this layout to a new input `shape` with the same total element
  // count. The forward index expressions are rewritten so that every element
  // keeps its row-major linearized position.
  //
  // Fast path: if shape is the same, return the original layout.
  if (StructuralEqual()(InputShape(), shape)) {
    return ffi::GetRef<Layout>(this);
  }
  // Step 1. Prove the product of InputShape is equal to the product of shape.
  PrimExpr input_shape_product = Integer(1);
  for (const auto &dim : InputShape()) {
    input_shape_product *= dim;
  }
  PrimExpr shape_product = Integer(1);
  for (const auto &dim : shape) {
    shape_product *= dim;
  }
  if (analyzer) {
    ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
        << "InputShape() = " << InputShape() << " shape = " << shape;
  } else {
    arith::Analyzer local_analyzer;
    ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
        << "InputShape() = " << InputShape() << " shape = " << shape;
  }
  // Step 2. Compute the row-major flat index from the new-shape placeholders:
  // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
  PrimExpr flat_index = Integer(0);
  for (size_t i = 0; i < shape.size(); ++i) {
    PrimExpr stride = Integer(1);
    for (size_t j = i + 1; j < shape.size(); ++j) {
      stride = stride * shape[j];
    }
    flat_index = flat_index + InputPlaceholder(i) * stride;
  }
  // Step 3. Convert the flat index back to original-shape indices.
  // For original shape [s0, s1, ..., sm]:
  //   i0 = flat_index // (s1 * s2 * ... * sm)
  //   i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
  //   ...
  Array<PrimExpr> original_indices;
  PrimExpr remaining = flat_index;
  for (size_t i = 0; i < InputShape().size(); ++i) {
    PrimExpr stride = Integer(1);
    for (size_t j = i + 1; j < InputShape().size(); ++j) {
      stride = stride * InputShape()[j];
    }
    original_indices.push_back(floordiv(remaining, stride));
    remaining = floormod(remaining, stride);
  }
  // Step 4. Substitute the old placeholders into forward_index_.
  // NOTE: the substitution must be simultaneous. `original_indices` are
  // expressed in the *new* placeholders, and InputPlaceholder(i) yields the
  // same canonical Var for position i in both shapes. Substituting one
  // variable at a time would also rewrite placeholders introduced by an
  // earlier replacement (wrong whenever both ranks are >= 2); applying a
  // single substitution map replaces each old placeholder exactly once and
  // never recurses into the replacements.
  Map<Var, PrimExpr> placeholder_sub;
  for (size_t i = 0; i < InputShape().size(); ++i) {
    placeholder_sub.Set(InputPlaceholder(i), original_indices[i]);
  }
  Array<PrimExpr> new_forward_index;
  for (const auto &fwd_expr : forward_index_) {
    new_forward_index.push_back(Substitute(fwd_expr, placeholder_sub));
  }
  return Layout(shape, new_forward_index);
}
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
                             arith::Analyzer *analyzer) const {
  // Reshape this fragment to a new input `shape` with the same element
  // count, rewriting both forward_index_ and forward_thread_ so that every
  // element keeps its row-major linearized position and thread binding.
  //
  // Fast path: identical input shape, return self.
  if (StructuralEqual()(InputShape(), shape)) {
    return ffi::GetRef<Fragment>(this);
  }
  // 1) Prove total number of elements remains the same.
  PrimExpr input_prod = Integer(1);
  for (const auto &d : InputShape())
    input_prod *= d;
  PrimExpr shape_prod = Integer(1);
  for (const auto &d : shape)
    shape_prod *= d;
  if (analyzer) {
    ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
        << "InputShape() = " << InputShape() << " shape = " << shape
        << " input fragment layout is = " << DebugOutput();
  } else {
    arith::Analyzer local_analyzer;
    ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
        << "InputShape() = " << InputShape() << " shape = " << shape;
  }
  // 2) Build row-major flat index from the new-shape placeholders.
  PrimExpr flat = Integer(0);
  for (size_t i = 0; i < shape.size(); ++i) {
    PrimExpr stride = Integer(1);
    for (size_t j = i + 1; j < shape.size(); ++j)
      stride = stride * shape[j];
    flat = flat + InputPlaceholder(i) * stride;
  }
  // 3) Recover original indices from the flat index.
  Array<PrimExpr> orig_indices;
  PrimExpr remain = flat;
  for (size_t i = 0; i < InputShape().size(); ++i) {
    PrimExpr stride = Integer(1);
    for (size_t j = i + 1; j < InputShape().size(); ++j)
      stride = stride * InputShape()[j];
    orig_indices.push_back(floordiv(remain, stride));
    remain = floormod(remain, stride);
  }
  // 4) Substitute old placeholders with expressions of the new indices.
  // NOTE: the substitution must be simultaneous. `orig_indices` are written
  // in the *new* placeholders, which reuse the same canonical Vars as the
  // old ones (InputPlaceholder(i) is canonical per position). Substituting
  // one variable at a time would also rewrite placeholders introduced by an
  // earlier replacement (wrong whenever both ranks are >= 2); a single-map
  // Substitute replaces each old placeholder exactly once.
  Map<Var, PrimExpr> placeholder_sub;
  for (size_t i = 0; i < InputShape().size(); ++i)
    placeholder_sub.Set(InputPlaceholder(i), orig_indices[i]);
  Array<PrimExpr> new_forward_index;
  for (const auto &e : forward_index_)
    new_forward_index.push_back(Substitute(e, placeholder_sub));
  PrimExpr new_forward_thread = Substitute(forward_thread_, placeholder_sub);
  // Rebuild the fragment; replicate extent is preserved, and the thread
  // range binding (if any) is re-applied on the reshaped fragment.
  Fragment reshaped(shape, new_forward_index, new_forward_thread,
                    ReplicateExtent(), std::nullopt);
  if (thread_range_.defined()) {
    reshaped = reshaped->BindThreadRange(thread_range_);
  }
  return reshaped;
}
Layout LayoutNode::Inverse() const { Layout LayoutNode::Inverse() const {
auto inverse_result = InverseWithLevel(); auto inverse_result = InverseWithLevel();
return std::move(inverse_result.first); return std::move(inverse_result.first);
} }
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters, PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread, const PrimExpr &forward_thread,
arith::Analyzer *analyzer) { arith::Analyzer *analyzer) {
......
...@@ -41,6 +41,10 @@ public: ...@@ -41,6 +41,10 @@ public:
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const; virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const; virtual Layout Inverse() const;
virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const; virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
virtual std::string DebugOutput() const; virtual std::string DebugOutput() const;
...@@ -81,6 +85,9 @@ public: ...@@ -81,6 +85,9 @@ public:
Array<PrimExpr> GetForwardVars() const final; Array<PrimExpr> GetForwardVars() const final;
Layout Inverse() const final; Layout Inverse() const final;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final; std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
PrimExpr ThreadExtent() const; PrimExpr ThreadExtent() const;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "../op/parallel.h" #include "../op/parallel.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
namespace tvm { namespace tvm {
...@@ -21,10 +22,54 @@ namespace tl { ...@@ -21,10 +22,54 @@ namespace tl {
using namespace tir; using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
                                            const BufferMap &vmap) {
  // Case 1: Already a BufferRegion — pass through unchanged.
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }
  // Case 2: BufferLoad — convert each index to a Range. A stride-1 Ramp
  // becomes [base, base + lanes); any scalar index becomes a unit range.
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> ranges;
    for (const PrimExpr &index : load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Hoist the IntImm casts so each is evaluated once.
        const auto *stride = ramp->stride.as<IntImmNode>();
        ICHECK(stride) << "Ramp stride must be IntImm";
        ICHECK_EQ(stride->value, 1)
            << "Only stride-1 Ramp is supported in region conversion";
        ICHECK(ramp->lanes.as<IntImmNode>())
            << "Scalable vector lanes not supported in region conversion";
        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        ranges.push_back(Range::FromMinExtent(index, 1));
      }
    }
    return BufferRegion(load->buffer, ranges);
  }
  // Case 3: Call nodes (only tl.region)
  if (const auto *call = arg.as<CallNode>()) {
    // tl.region(...) — reconstruct via RegionOp
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args, vmap);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
  }
  LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
  throw; // Unreachable
}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; // Accept BufferRegion/BufferLoad/tl.region for src/dst
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value; std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value; node->dim = args[3].as<IntImm>().value()->value;
node->type = ReduceType(reduce_type); node->type = ReduceType(reduce_type);
...@@ -369,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -369,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
if (level >= InferLevel::kStrict) if (level >= InferLevel::kStrict)
return {}; return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src)) { T.layout_map.count(src)) {
auto src_layout = T.layout_map[src].as<Fragment>().value(); auto src_layout = T.layout_map[src].as<Fragment>().value();
...@@ -422,6 +468,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -422,6 +468,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt) Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar() ->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds); ->BindThreadRange(T.thread_bounds);
if (!T.layout_map.count(dst)) if (!T.layout_map.count(dst))
return {{dst, dst_layout}}; return {{dst, dst_layout}};
else { else {
......
...@@ -82,9 +82,11 @@ public: ...@@ -82,9 +82,11 @@ public:
class ReduceOpNode : public TileOperatorNode { class ReduceOpNode : public TileOperatorNode {
public: public:
tir::Buffer src, dst; ///< Source and destination buffers tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension to reduce along // Optional: keep the original regions used to construct this op
ReduceType type; ///< Type of reduction operation BufferRegion srcRegion_, dstRegion_;
bool clear; ///< Whether to clear destination before reduction int dim; ///< Dimension to reduce along
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode, TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
TileOperatorNode); TileOperatorNode);
...@@ -94,6 +96,8 @@ public: ...@@ -94,6 +96,8 @@ public:
refl::ObjectDef<ReduceOpNode>() refl::ObjectDef<ReduceOpNode>()
.def_ro("src", &ReduceOpNode::src) .def_ro("src", &ReduceOpNode::src)
.def_ro("dst", &ReduceOpNode::dst) .def_ro("dst", &ReduceOpNode::dst)
.def_ro("srcRegion", &ReduceOpNode::srcRegion_)
.def_ro("dstRegion", &ReduceOpNode::dstRegion_)
.def_ro("dim", &ReduceOpNode::dim) .def_ro("dim", &ReduceOpNode::dim)
.def_ro("type", &ReduceOpNode::type) .def_ro("type", &ReduceOpNode::type)
.def_ro("clear", &ReduceOpNode::clear); .def_ro("clear", &ReduceOpNode::clear);
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <tvm/tir/utils.h> #include <tvm/tir/utils.h>
#include <algorithm>
#include <queue> #include <queue>
#include "../layout/utils.h" #include "../layout/utils.h"
...@@ -105,20 +106,60 @@ public: ...@@ -105,20 +106,60 @@ public:
"required for layout inference."; "required for layout inference.";
// Run InferLayout // Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates = auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob}, &analyzer_, buffer_oob},
level); level);
// Process the returned updates // Process the returned updates
for (const auto &[buffer, layout] : updates) { for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';
// Basic validity checks // Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer."; ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout."; ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
// Helper: propagate inferred layout to alias buffers (same data Var)
auto propagate_alias = [&](const Buffer &src_buffer,
const Layout &src_layout) {
if (!buffer_data_to_buffers_.count(src_buffer->data))
return;
const auto &siblings = buffer_data_to_buffers_[src_buffer->data];
for (const auto &sib : siblings) {
if (sib.same_as(src_buffer))
continue;
bool shapes_equal =
src_layout->InputShape().size() == sib->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < src_layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(src_layout->InputShape()[i],
sib->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout target_layout =
shapes_equal ? src_layout
: src_layout->Reshape(sib->shape, &analyzer_);
if (layout_map.count(sib)) {
ICHECK(target_layout->IsEqual(layout_map[sib].get()))
<< "Get different layout for alias buffer " << sib
<< " (data-shared with " << src_buffer
<< ")\n current: " << target_layout->DebugOutput()
<< "\n previous: " << layout_map[sib]->DebugOutput();
} else {
layout_map.Set(sib, target_layout);
if (update_queue && use_list_.count(sib)) {
for (int idx : use_list_[sib]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
}
}
}
}
};
if (layout_map.count(buffer)) { if (layout_map.count(buffer)) {
// If new layout contains the old one, update map // If new layout contains the old one, update map
if (buffer.scope() == "local.fragment" && if (buffer.scope() == "local.fragment" &&
...@@ -153,8 +194,8 @@ public: ...@@ -153,8 +194,8 @@ public:
if (ProveFragmentContains(src_layout, dst_layout, indices, indices, if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) { inner_analyzer)) {
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from " // Propagate to alias buffers as well
<< src_layout->DebugOutput() << ", accepted" << '\n'; propagate_alias(buffer, layout);
continue; continue;
} }
} }
...@@ -163,10 +204,13 @@ public: ...@@ -163,10 +204,13 @@ public:
<< "Get different layout for " << buffer << "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput() << "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput(); << "\n previous layout: " << layout_map[buffer]->DebugOutput();
// Ensure aliases are consistent too
propagate_alias(buffer, layout);
} else { } else {
// Otherwise, update map // Otherwise, update map
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n'; // Propagate to alias buffers (may enqueue their users)
propagate_alias(buffer, layout);
if (!update_queue) if (!update_queue)
continue; continue;
...@@ -272,6 +316,46 @@ public: ...@@ -272,6 +316,46 @@ public:
// step 3: relax constraints to free and re-run // step 3: relax constraints to free and re-run
InferInFreeMode(layout_map, strict_layout_map); InferInFreeMode(layout_map, strict_layout_map);
// step 4: finalize alias layouts by Var
// For each storage var, if any buffer in the group has a layout,
// propagate (reshape if needed) to the rest to ensure completeness.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
// Find a representative with existing layout
Optional<Buffer> rep;
Optional<Layout> rep_layout;
for (const auto &buf : buffers) {
if (layout_map.count(buf)) {
rep = buf;
rep_layout = layout_map[buf];
break;
}
}
if (!rep_layout.defined())
continue;
for (const auto &buf : buffers) {
if (!layout_map.count(buf)) {
bool shapes_equal =
rep_layout.value()->InputShape().size() == buf->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < rep_layout.value()->InputShape().size();
++i) {
if (!analyzer_.CanProveEqual(rep_layout.value()->InputShape()[i],
buf->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout reshaped =
shapes_equal
? rep_layout.value()
: rep_layout.value()->Reshape(buf->shape, &analyzer_);
layout_map.Set(buf, reshaped);
}
}
}
// Check that all local.fragment buffers have inferred layouts // Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) { for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment") { if (buffer.scope() == "local.fragment") {
...@@ -314,7 +398,13 @@ public: ...@@ -314,7 +398,13 @@ public:
void Collect(const PrimFunc &f) { void Collect(const PrimFunc &f) {
for (const auto &[_, buffer] : f->buffer_map) { for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer); if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
} }
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) ICHECK(target.defined())
...@@ -324,13 +414,25 @@ public: ...@@ -324,13 +414,25 @@ public:
} }
private: private:
// Build a Var -> representative Buffer map for legacy consumers that still
// expect exactly one buffer per storage var; the representative is the
// first buffer recorded for each var.
// TODO(lei): phaseout buffer_map in future.
Map<Var, Buffer> GetBufferMap() const {
  Map<Var, Buffer> result;
  for (const auto &kv : buffer_data_to_buffers_) {
    const auto &group = kv.second;
    if (group.empty())
      continue;
    result.Set(kv.first, group.front());
  }
  return result;
}
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op); IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analysis the call node to the global function. // Do not analysis the call node to the global function.
if (op->op.as<GlobalVarNode>()) if (op->op.as<GlobalVarNode>())
return; return;
auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), buffer_data_to_buffer_); auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), GetBufferMap());
if (p.defined()) { if (p.defined()) {
for (const auto &arg : op->args) { for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) { if (auto buffer = getBufferFromAccessPtr(arg)) {
...@@ -394,12 +496,18 @@ private: ...@@ -394,12 +496,18 @@ private:
if (call->op.same_as(builtin::tvm_access_ptr())) { if (call->op.same_as(builtin::tvm_access_ptr())) {
auto var_opt = call->args[1].as<Var>(); auto var_opt = call->args[1].as<Var>();
if (!var_opt.has_value()) { if (!var_opt.has_value()) {
DLOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: " LOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
<< call->args[1]->GetTypeKey(); << call->args[1]->GetTypeKey();
return std::nullopt; return std::nullopt;
} }
const auto &var = var_opt.value(); const auto &var = var_opt.value();
return buffer_data_to_buffer_[var]; if (buffer_data_to_buffers_.count(var)) {
const auto &buffers = buffer_data_to_buffers_[var];
if (!buffers.empty()) {
return buffers[0]; // Return the first buffer
}
}
return std::nullopt;
} else if (call->op.same_as(RegionOp::Get())) { } else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer; return call->args[0].as<BufferLoadNode>()->buffer;
} }
...@@ -442,21 +550,55 @@ private: ...@@ -442,21 +550,55 @@ private:
void VisitStmt_(const BlockNode *op) final { void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) { for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer); if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
} }
// First, visit the block body to collect all buffers from
// BufferLoad/BufferStore
IRVisitorWithAnalyzer::VisitStmt_(op);
// After visiting, apply layouts to all collected buffers
if (op->annotations.count(attr::kLayoutMap)) { if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout> // Check if the layout map is Map<Var, Layout>
auto map = auto map =
op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value(); op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) { for (const auto &[var, layout] : map) {
ICHECK(buffer_data_to_buffer_.count(var)) ICHECK(buffer_data_to_buffers_.count(var))
<< "buffer " << var << " is not found in the block"; << "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var]; const auto &buffers = buffer_data_to_buffers_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape)); ICHECK(!buffers.empty()) << "buffer list for " << var << " is empty";
annotated_layout_map_.Set(buffer, layout); // Apply layout to all buffers associated with this var
for (const auto &buffer : buffers) {
// Reshape the layout to match the buffer's shape
// Check if shapes are structurally equal
bool shapes_equal =
layout->InputShape().size() == buffer->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(layout->InputShape()[i],
buffer->shape[i])) {
shapes_equal = false;
break;
}
}
}
if (shapes_equal) {
annotated_layout_map_.Set(buffer, layout);
} else {
auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_);
annotated_layout_map_.Set(buffer, reshaped_layout);
}
}
} }
} }
IRVisitorWithAnalyzer::VisitStmt_(op);
} }
void VisitStmt_(const AttrStmtNode *op) final { void VisitStmt_(const AttrStmtNode *op) final {
...@@ -470,7 +612,67 @@ private: ...@@ -470,7 +612,67 @@ private:
IRVisitorWithAnalyzer::VisitStmt_(op); IRVisitorWithAnalyzer::VisitStmt_(op);
} }
Map<Var, Buffer> buffer_data_to_buffer_; void VisitExpr_(const BufferLoadNode *op) final {
// Collect buffer from BufferLoad
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
// Collect buffer from BufferStore
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
Map<Var, Array<Buffer>> buffer_data_to_buffers_;
std::vector<ObjectRef> infer_list_stmt_; std::vector<ObjectRef> infer_list_stmt_;
std::vector<TileOperator> infer_list_; std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual> std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
...@@ -513,12 +715,33 @@ private: ...@@ -513,12 +715,33 @@ private:
if (infer_indices.empty()) if (infer_indices.empty())
continue; continue;
// Union all infer_list_ indices that share the same buffer // Union all infer_list_ indices that share the same Buffer object
int first_idx = infer_indices[0]; int first_idx = infer_indices[0];
for (size_t i = 1; i < infer_indices.size(); i++) { for (size_t i = 1; i < infer_indices.size(); i++) {
uf.Union(first_idx, infer_indices[i]); uf.Union(first_idx, infer_indices[i]);
} }
} }
// Additionally, union across buffers that share the same underlying
// buffer->data (Var). This handles cases like reshape where multiple
// Buffer objects alias the same storage.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
std::vector<int> merged;
for (const auto &buf : buffers) {
auto it = use_list_.find(buf);
if (it != use_list_.end()) {
const auto &vec = it->second;
merged.insert(merged.end(), vec.begin(), vec.end());
}
}
if (merged.size() > 1) {
std::sort(merged.begin(), merged.end());
merged.erase(std::unique(merged.begin(), merged.end()), merged.end());
int first = merged[0];
for (size_t i = 1; i < merged.size(); ++i) {
uf.Union(first, merged[i]);
}
}
}
std::unordered_map<int, std::vector<int>> components; std::unordered_map<int, std::vector<int>> components;
for (int i = 0; i < infer_list_.size(); i++) { for (int i = 0; i < infer_list_.size(); i++) {
int root = uf.Find(i); int root = uf.Find(i);
......
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang as tl import tilelang as tl
import torch
def reshape_test(N, M, dtype): def reshape_test(N, M, dtype):
...@@ -129,5 +130,137 @@ def test_reshape_smem_2d_2_1d(): ...@@ -129,5 +130,137 @@ def test_reshape_smem_2d_2_1d():
run_reshape_smem_2d_2_1d(2048, 64, "float16") run_reshape_smem_2d_2_1d(2048, 64, "float16")
def reshape_fragment_test(N, M, dtype):
    """Build a kernel that reshapes a (N//M, M) fragment to (N,) and copies it out.

    Exercises fragment-layout reshape: the 2-D local fragment is flattened
    with T.reshape and written to shared memory, then back to global B.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N // M, M), dtype),
            B: T.Tensor((N,), dtype),
    ):
        # Single block, one warp of 32 threads.
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
            A_local = T.alloc_fragment((N // M, M), dtype)
            B_shared = T.alloc_shared((N,), dtype, scope="shared")
            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            # Flatten the 2-D fragment to 1-D; layout must be re-inferred.
            A_local_reshape = T.reshape(A_local, [N])
            T.copy(A_local_reshape, B_shared)
            T.copy(B_shared, B)

    return main
def run_reshape_fragment(N, M, dtype):
    """Compile the fragment-reshape kernel and check it against torch reshape."""
    program = reshape_fragment_test(N, M, dtype)
    # Disable TMA lowering / warp specialization so the plain copy path is used.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        # Reference: plain reshape of the input tensor.
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_fragment():
    """Run the fragment-reshape check for a float32 and a float16 config."""
    for n, m, dt in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reshape_fragment(n, m, dt)
def reshape_layout_transform_shared(N, M, dtype):
    """Build a kernel that reshapes a swizzle-annotated shared buffer.

    A_shared carries an explicit MMA swizzle layout via annotate_layout;
    reshaping it to 1-D forces the annotated layout to be reshaped too.
    """
    import tilelang.language as T
    from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout

    @T.prim_func
    def main(
            A: T.Tensor((N // M, M), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
            # Force a non-trivial (swizzled) layout on the shared buffer.
            T.annotate_layout({
                A_shared: make_mma_swizzle_layout(A_shared),
            })
            T.copy(A, A_shared)
            # Reshape must carry the annotated layout to the new shape.
            A_shared_reshape = T.reshape(A_shared, [N])
            T.copy(A_shared_reshape, B)

    return main
def run_reshape_layout_transform_shared(N, M, dtype):
    """Compile the swizzled-shared reshape kernel and compare with torch reshape."""
    program = reshape_layout_transform_shared(N, M, dtype)
    # Disable TMA lowering / warp specialization so the plain copy path is used.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        # Reference: plain reshape of the input tensor.
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_layout_transform_shared():
    """Run the swizzled-shared reshape check for two configurations."""
    for n, m, dt in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reshape_layout_transform_shared(n, m, dt)
def reduce_after_reshape_test(N, M, dtype):
    """Build a kernel that reshapes a 1-D fragment to 2-D and reduces over dim 1.

    Exercises the BufferRegion-based reduce path on a reshaped fragment:
    (N,) -> (N//M, M) followed by a row-wise max.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N // M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N,), dtype, scope="shared")
            A_local = T.alloc_fragment((N,), dtype)
            B_local = T.alloc_fragment((N // M,), dtype)
            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            # Reshape the flat fragment into rows of length M, then reduce
            # along the row dimension.
            A_local_reshape = T.reshape(A_local, [N // M, M])
            T.reduce_max(A_local_reshape, B_local, dim=1)
            T.copy(B_local, B)

    return main
def run_reduce_after_reshape(N, M, dtype):
    """Compile the reduce-after-reshape kernel and compare with torch.max."""
    program = reduce_after_reshape_test(N, M, dtype)
    # Disable TMA lowering / warp specialization so the plain copy path is used.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        # Reference: row-wise max over the reshaped input.
        return torch.max(A.reshape(N // M, M), dim=1).values

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reduce_after_reshape():
    """Run the reduce-after-reshape check for two configurations."""
    for n, m, dt in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reduce_after_reshape(n, m, dt)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
from tvm import tir from tvm import tir
from tilelang.language import copy, macro, alloc_shared, alloc_fragment from tilelang.language import copy, macro, alloc_shared, alloc_fragment
from tilelang.language.utils import buffer_to_tile_region
from tilelang.utils.language import is_shared, is_fragment from tilelang.utils.language import is_shared, is_fragment
from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import IRBuilder
...@@ -51,8 +52,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -51,8 +52,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"), buffer_to_tile_region(red_frag_in, "r"),
red_frag_out.access_ptr("w"), buffer_to_tile_region(red_frag_out, "w"),
reduce_type, reduce_type,
dim, dim,
clear, clear,
...@@ -66,8 +67,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -66,8 +67,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"), buffer_to_tile_region(red_frag_in, "r"),
out.access_ptr("w"), buffer_to_tile_region(out, "w"),
reduce_type, reduce_type,
dim, dim,
clear, clear,
...@@ -79,8 +80,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -79,8 +80,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"), buffer_to_tile_region(buffer, "r"),
red_frag_out.access_ptr("w"), buffer_to_tile_region(red_frag_out, "w"),
reduce_type, reduce_type,
dim, dim,
clear, clear,
...@@ -90,8 +91,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -90,8 +91,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"), buffer_to_tile_region(buffer, "r"),
out.access_ptr("w"), buffer_to_tile_region(out, "w"),
reduce_type, reduce_type,
dim, dim,
clear, clear,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment