"src/tl_templates/vscode:/vscode.git/clone" did not exist on "eb41574431608e2a96d3d8941f9c1e6d775f228e"
Commit c39e540a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Improve Thread Variable Handling in Layout Inference (#179)

* [Refactor] Improve Thread Variable Handling in Layout Inference

- Update layout inference to handle thread variables more robustly
- Add explicit size check between infer_list_ and thread_var_vec_
- Modify thread variable access to use per-iteration thread variable
- Simplify thread predicate retrieval logic
- Add minor code cleanup and return variable assignment

* [Refactor] Update Layout Inference Copyright and Simplify Return Logic

- Replace Apache License header with Microsoft Corporation copyright notice
- Simplify LayoutInference function by directly returning substituted function
- Remove unnecessary variable assignment in return statement

* [Refactor] Update Layout Inference Copyright to Tile-AI Corporation

- Change copyright notice from Microsoft Corporation to Tile-AI Corporation
- Maintain existing file structure and licensing header
parent 8ccf6ea2
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file layout_inference.cc
* \brief infer the fragment/shared memory layout
......@@ -225,7 +206,12 @@ public:
// Collect layout info for For nodes
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
for (auto &base_infer : infer_list_) {
ICHECK(infer_list_.size() == thread_var_vec_.size())
<< "infer_list_ and thread_var_vec_ size mismatch";
for (int i = 0; i < infer_list_.size(); i++) {
std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
auto thread_var = thread_var_vec_[i];
// Check if base_infer is valid
ICHECK(base_infer != nullptr) << "Null pointer encountered in "
"infer_list_ while collecting for_map.";
......@@ -238,10 +224,10 @@ public:
for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
// thread_var_ should be defined if we rely on it
ICHECK(thread_var_.defined())
<< "thread_var_ is not defined. Cannot retrieve predicate.";
ICHECK(thread_var.defined())
<< "thread_var is not defined. Cannot retrieve predicate.";
if (auto predicate = for_infer->GetPredicate(thread_var_->var)) {
if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
predicate_map.Set(for_infer->GetRoot(), predicate.value());
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment