Unverified Commit 5bd3f942 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Enhancement] Add role assignment for AllocateNode in warp specialization (#657)

- Implemented a new role assignment for `AllocateNode` in `warp_specialized_rewriter.cc`, setting the role to `kConsumer` to ensure proper handling of memory allocation scenarios.
- This can avoid bug when using T.reduce(clear=False)
parent 8205791d
...@@ -170,6 +170,12 @@ public: ...@@ -170,6 +170,12 @@ public:
SetRole(op, GetRole(op->block)); SetRole(op, GetRole(op->block));
} }
void VisitStmt_(const AllocateNode *op) final {
StmtVisitor::VisitStmt_(op);
Role role = Role::kConsumer;
SetRole(op, role);
}
template <class NodeType> void HandleBodyStmt(const NodeType *op) { template <class NodeType> void HandleBodyStmt(const NodeType *op) {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body)); SetRole(op, GetRole(op->body));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment