"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "c996cddf6543a6c87521278fbe794d69dd4fddc7"
Unverified Commit 1b42c87b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Update Fragment Indexing in ParallelOpNode's InferLayout Method (#1359)

This commit refines the Fragment creation process in the InferLayout method of ParallelOpNode. It removes the unnecessary forward_index array and utilizes default fragment indexing for consistency with other operations. Additionally, it binds the thread range to enhance comparability across different operations.
parent c6a19fb2
...@@ -252,17 +252,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -252,17 +252,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
forward_vars.push_back( forward_vars.push_back(
IterVar(Range(0, s), Var(), IterVarType::kDataPar)); IterVar(Range(0, s), Var(), IterVarType::kDataPar));
} }
Array<PrimExpr> forward_index;
for (const auto &iv : forward_vars) {
forward_index.push_back(iv->var);
}
Var rep; Var rep;
auto rep_iter = auto rep_iter =
IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar); IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
// Use default fragment indexing (single output dim) to
// stay consistent with other ops (e.g., ReduceOp), and
// bind the thread range for comparability.
const PrimExpr &forward_thread = rep; const PrimExpr &forward_thread = rep;
results.Set(buffer, Fragment(forward_vars, forward_index, auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread,
forward_thread, rep_iter)); rep_iter)
->BindThreadRange(T.thread_bounds);
results.Set(buffer, frag);
} }
} }
return results; return results;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment