Unverified Commit 86c8bb46 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Improve scalar handling in CopyNode and update loop partition dtype logi (#1111)

* [Refactor] Improve scalar handling in CopyNode and update loop partition dtype logic

* Refactored CopyNode::MakeSIMTLoop to handle scalar cases more efficiently by moving the scalar check to the end of the function.
* Updated loop_partition.cc to set a default DataType for thread and vector extents, ensuring compatibility when loop_vars_ is empty.

* lint fix

* remove debug print
parent f14fb111
......@@ -299,10 +299,6 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
......@@ -332,6 +328,9 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial, body);
}
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
......@@ -1979,4 +1978,4 @@ TVM_FFI_STATIC_INIT_BLOCK({
Conv2DIm2ColOpNode::RegisterReflection();
});
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -189,8 +189,10 @@ public:
Fragment Partition(const For &op, int num_thread, int vectorize_size) {
this->VisitStmt(op);
ICHECK(!loop_vars_.empty());
DataType dtype = loop_vars_[0]->var.dtype();
DataType dtype = DataType::Int(32);
if (!loop_vars_.empty()) {
dtype = loop_vars_.back()->var.dtype();
}
PrimExpr flattened = make_const(dtype, 0);
PrimExpr vector_extent = make_const(dtype, vectorize_size);
PrimExpr thread_extent_const = make_const(dtype, num_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment