/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file op/parallel.cc
 * \brief Define the Parallel for operator.
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Mark how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];
      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}
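// Illustrative sketch (hypothetical fragment buffers A_frag/B_frag, not taken
// from this file): for a parallel nest such as
//
//   for (i, 0, 64) "parallel":
//     for (j, 0, 64) "parallel":
//       B_frag[i, j] = A_frag[i, j] + 1
//
// ParallelLoopNestVisitor (defined above) records loop_vars_ = {i, j},
// indice_map_ = {A_frag -> [i, j], B_frag -> [i, j]} and
// buffer_is_write_ = {B_frag} when the constructor below visits the loop
// nest; the layout inference that follows is driven by exactly this
// information.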
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

/*! \brief Infer the layout for parallel operations based on different
 * inference levels.
 *
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only allows explicitly defined
 *   layouts. Returns an empty layout map if loop_layout_ is not already
 *   defined. Used when exact layout control is required.
 *
 * - kCommon (1): Intermediate level between strict and free. Allows common
 *   layout patterns while maintaining some constraints.
 *
 * - kFree (0): Most permissive level. Allows maximum optimization freedom.
 *   Will attempt layout inference even without source buffers. Can generate
 *   new layouts based on vectorization and thread bounds. Used when maximum
 *   performance optimization is desired.
 */
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};
  auto block_size = T.thread_bounds->extent;
  // Step 1: try to infer the loop's partition from a source fragment.
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, indices] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer)) {
        source_buffer = buffer;
      } else {
        // Keep the buffer with the largest number of indices as
        // read_source_buffer, since inference based on that buffer is more
        // accurate.
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
        // If the buffer is not replicated, use it as source_buffer,
        // because the layout inference is more accurate.
        if (is_one(frag->ReplicateExtent())) {
          source_buffer = buffer;
        }
      }
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
          .SetThreadRange(T.thread_bounds);
    }
  };
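  // Illustrative sketch (hypothetical access pattern, not taken from this
  // file): if a source fragment maps element (x, y) to a thread via
  //   thread = src_layout->ForwardThread({x, y}, rep)
  // and the loop body accesses that buffer as buf[i, j + 1], the lambda above
  // substitutes the access indices to obtain
  //   loop_var_to_thread = src_layout->ForwardThread({i, j + 1}, rep),
  // i.e. the thread that owns each loop iteration, and wraps it into a new
  // Fragment over loop_vars_.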
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // // The loop doesn't need to be replicated.
      // if (!is_one(loop_layout_->ReplicateExtent()))
      //   loop_layout_ = loop_layout_->DeReplicate();
      // // If it still has replication, add a condition.
      // if (!is_one(loop_layout_->ReplicateExtent())) {
      //   auto inv = loop_layout_->Inverse();
      //   Array<PrimExpr> fwd;
      //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
      //     fwd.push_back(0);
      //   fwd.push_back(InputPlaceholder(0));
      //   auto rep = inv->Forward(fwd).back();
      //   AddPredicate(EQ(rep, 0));
      // }
    } else {
      // The vectorize size must be aware of the buffer remap,
      // as later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined.
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by the expected width.
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
    }

    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent, block_size))
      AddPredicate(
          LT(InputPlaceholder(0) - T.thread_bounds->min, loop_thread_extent));
  } else {
    return {};
  }

  // Step 2: check that the loop's partition aligns correctly with all source
  // fragments.
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases;
      // need to wildcard match the rhs with the lhs.
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " "
                            << source_buffer << "\nLHS = " << lhs
                            << "\nRHS = " << rhs;
    }
  }
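  // Illustrative sketch (hypothetical mapping, not taken from this file): if
  // loop_layout_ sends iteration (i, j) to thread i * 2 + j / 16 and the
  // source fragment sends element [i, j] to the same expression, lhs - rhs
  // simplifies to 0 and the check above passes; a nonzero difference would
  // mean some iteration runs on a thread that does not own the corresponding
  // fragment element, which is reported as a conflict.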
  // Step 3: infer the other fragments' layouts from the loop's partition.
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer)) {
      results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
                              T.thread_bounds));
    }

    // Though there may exist some conflicts, it's fine. Layout inference
    // conflicts for local.fragment cannot be handled here, because the
    // source_buffer is not always available.
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment =
            CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
              << "\nError body begin:\n"
              << GetRoot()->body << "\nError body end"
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use shared memory to transform the "
                 "layout";
        }
      }
    }
  }
  return results;
}

Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return NullOpt;
  }
}

Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  // indice_rep_extent is the size of rep_b; the destination buffer is
  // replicated by both this extent and the loop's own replication extent.
  PrimExpr indice_rep_extent = ind_inv->InputShape().back();
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm