"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "dda6a126463b3d560f2525e8daea2f0c7be9f56f"
Commit 444b7c4e authored by Lei Wang, committed by LeiWang1999
Browse files

[Bugfix] Enhance layout inference pass for flexibility (#550)

* Enhance Layout

* strict update

* lint fix

* Refactor layout inference by removing unnecessary logging statements in `parallel.cc` and `layout_inference.cc`. This cleanup enhances code readability and reduces log clutter during layout inference steps.

* lint fix

* Refactor file copying logic in setup.py to simplify directory creation and file copying process. Removed unnecessary existence check before copying source files to the target directory.
parent 319bc6b1
......@@ -348,11 +348,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
if not os.path.exists(os.path.join(target_dir, os.path.basename(source_dir))):
# if not exists, copy the file
# as tox will copy the file to the build
# directory based on manifest file
shutil.copy2(source_dir, target_dir)
shutil.copy2(source_dir, target_dir)
# copy the tl_templates
TILELANG_SRC = [
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file op/parallel.cc
 * \brief Define the Parallel operator
......@@ -162,9 +143,10 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
indice_map_[read_source_buffer].size()) {
read_source_buffer = buffer;
}
// If the buffer is not replicated, use it as source_buffer
// because the layout inference is more accurate
if (is_one(frag->ReplicateExtent())) {
      // If the buffer is not replicated and its shape is equal to the
      // source_buffer's, use it as the source_buffer because the layout
      // inference is more accurate
if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
source_buffer = buffer;
}
}
......@@ -275,7 +257,6 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
T.thread_bounds));
}
  // Though some conflicts may exist, that is fine.
  // Layout inference conflicts for local.fragment cannot be handled here
  // because the source_buffer is not always available
......@@ -288,6 +269,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout =
dst_layout_fragment.as<Fragment>().get();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
*as_const_int(src_layout->ReplicateExtent()))) {
results.Set(buffer, dst_layout_fragment);
continue;
}
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Layout may conflict with ParallelOp for buffer " << buffer
......@@ -314,21 +302,18 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer))
if (IsCommonAccessIndice(buffer)) {
return loop_layout_;
}
PrimExpr rep_b = MakeFlattenedExpression(
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
Array<PrimExpr> fwd;
for (size_t i = 0; i < buffer->shape.size(); i++) {
fwd.push_back(InputPlaceholder(i));
......@@ -337,7 +322,6 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
}
......
......@@ -285,6 +285,21 @@ public:
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
if (layout_map.count(buffer)) {
        // If the replicate size of this buffer is greater than the old one
if (buffer.scope() == "local.fragment" &&
level != InferLevel::kStrict) {
const FragmentNode *dst_layout = layout.as<Fragment>().get();
const FragmentNode *src_layout =
layout_map[buffer].as<Fragment>().get();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
*as_const_int(src_layout->ReplicateExtent()))) {
// update map
layout_map.Set(buffer, layout);
continue;
}
}
// If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
......
......@@ -65,6 +65,8 @@ def compile(
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
"""
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
return cached(
func=func,
out_idx=out_idx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment