Unverified Commit 86c8bb46, authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Improve scalar handling in CopyNode and update loop partition dtype logic (#1111)

* [Refactor] Improve scalar handling in CopyNode and update loop partition dtype logic

* Refactored CopyNode::MakeSIMTLoop to handle scalar cases more efficiently by moving the scalar check to the end of the function.
* Updated loop_partition.cc to set a default DataType for thread and vector extents, ensuring compatibility when loop_vars_ is empty.

* lint fix

* remove debug print
parent f14fb111
...@@ -299,10 +299,6 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, ...@@ -299,10 +299,6 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars(); Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.empty(); bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars) for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom); analyzer->Bind(iv->var, iv->dom);
...@@ -332,6 +328,9 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -332,6 +328,9 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt body = BufferStore(dst, value, dst_indices); Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined()) if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body); body = IfThenElse(dst_predicate, body);
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial, body);
}
for (int i = loop_vars.size() - 1; i >= 0; i--) { for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {}; Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) { if (coalesced_width.defined()) {
......
...@@ -189,8 +189,10 @@ public: ...@@ -189,8 +189,10 @@ public:
Fragment Partition(const For &op, int num_thread, int vectorize_size) { Fragment Partition(const For &op, int num_thread, int vectorize_size) {
this->VisitStmt(op); this->VisitStmt(op);
ICHECK(!loop_vars_.empty()); DataType dtype = DataType::Int(32);
DataType dtype = loop_vars_[0]->var.dtype(); if (!loop_vars_.empty()) {
dtype = loop_vars_.back()->var.dtype();
}
PrimExpr flattened = make_const(dtype, 0); PrimExpr flattened = make_const(dtype, 0);
PrimExpr vector_extent = make_const(dtype, vectorize_size); PrimExpr vector_extent = make_const(dtype, vectorize_size);
PrimExpr thread_extent_const = make_const(dtype, num_thread); PrimExpr thread_extent_const = make_const(dtype, num_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment