Unverified Commit 0788feb8 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Improve buffer usage tracking in MakePackedAPI (#1435)

* Added detailed logging for data and shape variable parameters during buffer usage detection in the MakePackedAPI function.
* Refactored the UsedBufferDetector to differentiate between used parameters by data and shape variables, enhancing clarity in buffer management.
* Updated logic to ensure minimal carrier buffers are selected for shape symbols, improving the efficiency of parameter handling.
parent fba12a5f
......@@ -324,8 +324,11 @@ PrimFunc MakePackedAPI(PrimFunc func) {
record_shape_vars(buf->elem_offset);
}
// A visitor that marks a buffer as used when its underlying data var is
// referenced (e.g. BufferLoad/BufferStore or any direct var usage).
// A visitor that records
// - which parameter buffers are used via their data var (load/store/direct),
// - which shape/stride/offset symbols are referenced in the body.
// Shape symbols are not immediately attributed to all carrier buffers here;
// a minimal carrier set is selected after visiting.
struct UsedBufferDetector : public StmtExprVisitor {
UsedBufferDetector(
const std::unordered_map<const VarNode *, const VarNode *> &data2param,
......@@ -335,26 +338,25 @@ PrimFunc MakePackedAPI(PrimFunc func) {
void VisitExpr_(const VarNode *op) override {
auto it = data2param.find(op);
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
auto it2 = shape2params.find(op);
if (it2 != shape2params.end()) {
for (const VarNode *p : it2->second)
used_params.insert(p);
used_shape_vars.insert(op);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
StmtExprVisitor::VisitExpr_(op);
}
......@@ -362,7 +364,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
const std::unordered_map<const VarNode *, const VarNode *> &data2param;
const std::unordered_map<const VarNode *, std::vector<const VarNode *>>
&shape2params;
std::unordered_set<const VarNode *> used_params;
std::unordered_set<const VarNode *> used_params_by_data;
std::unordered_set<const VarNode *> used_shape_vars;
};
UsedBufferDetector detector(data_var2param, shape_var2params);
......@@ -371,7 +374,30 @@ PrimFunc MakePackedAPI(PrimFunc func) {
// Build the packed argument handling. While doing so, keep track of whether
// each parameter buffer is actually used. Unused input buffers can be
// nullable and do not require DLTensor field dereferences.
std::unordered_set<const VarNode *> used_param_buffers = detector.used_params;
//
// Start from buffers used via data-var (definitely non-NULL), then for each
// referenced shape symbol pick a minimal "carrier" buffer that provides the
// symbol. Prefer carriers that are already used-by-data; otherwise pick one
// arbitrary carrier to ensure the symbol is bound.
std::unordered_set<const VarNode *> used_param_buffers =
detector.used_params_by_data;
for (const VarNode *sym : detector.used_shape_vars) {
auto it = shape_var2params.find(sym);
if (it == shape_var2params.end())
continue;
const auto &carriers = it->second;
bool has_used_carrier = false;
for (const VarNode *p : carriers) {
if (used_param_buffers.count(p)) {
has_used_carrier = true;
break;
}
}
if (!has_used_carrier && !carriers.empty()) {
// Choose the first carrier to anchor this symbol.
used_param_buffers.insert(carriers.front());
}
}
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment