Unverified Commit 9cbbbbc6 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Layout] Support layout forward with multi dimension (#867)

* Enhance LayoutNode::Forward method to handle variable transformations more robustly

- Updated the method to check for a minimum number of input dimensions.
- Introduced a mechanism to transform the last InputDim() elements of the input variables.
- Concatenated transformed variables with the remaining input variables for a comprehensive output.

* Refactor LayoutNode::Forward method for improved readability

- Removed unnecessary whitespace to enhance code clarity.
- Maintained existing functionality while streamlining the transformation process of input variables.
parent 86aaf3c1
......@@ -115,13 +115,32 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
if (vars.empty())
return forward_index_;
ICHECK_EQ(vars.size(), InputDim());
ICHECK_GE(vars.size(), InputDim());
// Take the last InputDim() elements for transformation
Array<PrimExpr> transform_vars;
for (size_t i = vars.size() - InputDim(); i < vars.size(); i++) {
transform_vars.push_back(vars[i]);
}
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]);
vmap.Set(InputPlaceholder(i), transform_vars[i]);
}
return forward_index_.Map(
Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
// Concatenate with the remaining elements from vars
Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
result.push_back(vars[i]);
}
for (const auto &expr : transformed) {
result.push_back(expr);
}
return result;
}
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment