Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
core_gather_headers()

# Register the group-scheduler translation units into the cinnapi source list.
gather_srcs(cinnapi_src SRCS base_group_scheduler.cc)
gather_srcs(cinnapi_src SRCS st_shape_group_scheduler.cc)
gather_srcs(cinnapi_src SRCS dy_shape_group_scheduler.cc)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
namespace cinn {
namespace ir {
// Factory: construct the scheduler variant matching the shape mode.
// is_dy_shape == true selects the dynamic-shape scheduler, otherwise the
// static-shape scheduler is returned.
std::unique_ptr<GroupScheduler> GroupScheduler::Make(
    ir::IRSchedule* ir_sch,
    const std::unordered_set<std::string>& output_tensor_names,
    const common::Target& target,
    bool is_dy_shape) {
  if (is_dy_shape) {
    return std::make_unique<DynamicShapeGroupScheduler>(
        ir_sch, output_tensor_names, target);
  }
  return std::make_unique<StaticShapeGroupScheduler>(
      ir_sch, output_tensor_names, target);
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule_block_graph.h"
namespace cinn {
namespace ir {
using SymbolicPredicate = Expr;
/**
 * The base class used for scheduling fusion groups.
 *
 * Concrete subclasses (static-shape / dynamic-shape) implement Schedule()
 * and GetIRs(); callers normally obtain an instance through the static
 * Make() factory.
 */
class GroupScheduler {
 public:
  // NOTE(review): output_tensor_names is stored as a const reference; the
  // caller must keep the set alive for the scheduler's lifetime — verify at
  // call sites.
  GroupScheduler(ir::IRSchedule* ir_sch,
                 const std::unordered_set<std::string>& output_tensor_names,
                 const common::Target& target)
      : ir_sch_(ir_sch),
        output_tensor_names_(output_tensor_names),
        target_(target) {
    // Build the ScheduleBlockNode graph over the current schedule state.
    schedule_block_graph_ = std::make_unique<ir::ScheduleBlockGraph>(*ir_sch_);
  }
  // Factory: returns a DynamicShapeGroupScheduler when is_dy_shape is true,
  // otherwise a StaticShapeGroupScheduler.
  static std::unique_ptr<GroupScheduler> Make(
      ir::IRSchedule* ir_sch,
      const std::unordered_set<std::string>& output_tensor_names,
      const common::Target& target,
      bool is_dy_shape = false);
  virtual ~GroupScheduler() = default;
  // Runs the scheduling passes, mutating the underlying IRSchedule.
  virtual void Schedule() = 0;
  // Returns the scheduled IR(s), each guarded by a symbolic predicate.
  virtual std::vector<std::pair<SymbolicPredicate, ir::Expr>> GetIRs() = 0;

 protected:
  ir::IRSchedule* ir_sch_;  // not owned
  const std::unordered_set<std::string>& output_tensor_names_;  // not owned
  const common::Target& target_;
  // Graph in units of ScheduleBlockNode, each node corresponds to a
  // ScheduleBlock in IR.
  std::unique_ptr<ir::ScheduleBlockGraph> schedule_block_graph_;
};
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
namespace cinn {
namespace ir {
void DynamicShapeGroupScheduler::Schedule() {
// Fake schedule for test
int max_spacial_numel = 1;
ScheduleBlockNode* node = schedule_block_graph_->EndPoints()[0];
ir::Expr block_realize = node->Block();
std::vector<ir::Expr> loops = ir_sch_->GetLoops(block_realize);
ir::Expr extent = loops[0].As<ir::For>()->extent;
ir::Expr predicate1 = ir::LE::Make(extent, Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch1 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
ScheduleBlockGraph sbg1(*new_ir_sch1);
sbg1.NodesWalk([&](ir::ScheduleBlockNode* node) {
std::vector<cinn::ir::Expr> splited_loops =
new_ir_sch1->Split(new_ir_sch1->GetLoops(node->Block())[0], {-1, 1});
new_ir_sch1->Bind(splited_loops[1], "blockIdx.x");
new_ir_sch1->Bind(new_ir_sch1->GetLoops(node->Block())[2], "threadIdx.x");
});
ir_schs_.emplace_back(predicate1, std::move(new_ir_sch1));
ir::Expr predicate2 = ir::GT::Make(extent, Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch2 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
ScheduleBlockGraph sbg2(*new_ir_sch2);
sbg2.NodesWalk([&](ir::ScheduleBlockNode* node) {
std::vector<cinn::ir::Expr> splited_loops =
new_ir_sch2->Split(new_ir_sch2->GetLoops(node->Block())[0], {-1, 1024});
new_ir_sch2->Bind(splited_loops[1], "blockIdx.x");
new_ir_sch2->Bind(new_ir_sch2->GetLoops(node->Block())[2], "threadIdx.x");
});
ir_schs_.emplace_back(predicate2, std::move(new_ir_sch2));
}
// Pair each predicate with the lowered body of its candidate schedule.
std::vector<std::pair<SymbolicPredicate, ir::Expr>>
DynamicShapeGroupScheduler::GetIRs() {
  std::vector<std::pair<SymbolicPredicate, ir::Expr>> result;
  result.reserve(ir_schs_.size());
  for (auto& entry : ir_schs_) {
    result.emplace_back(entry.first, entry.second->GetModule().GetExprs()[0]);
  }
  return result;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
namespace cinn {
namespace ir {
/**
 * The class used for scheduling fusion groups with dynamic shape.
 * Note: Currently only CUDA backend is supported.
 */
class DynamicShapeGroupScheduler : public GroupScheduler {
 public:
  DynamicShapeGroupScheduler(
      ir::IRSchedule* ir_sch,
      const std::unordered_set<std::string>& output_tensor_names,
      const common::Target& target)
      : GroupScheduler(ir_sch, output_tensor_names, target) {}
  // Generates several candidate schedules, each valid under a symbolic
  // predicate over the dynamic loop extents (see the .cc for the strategy).
  void Schedule() override;
  // Returns one (predicate, scheduled IR) pair per candidate schedule.
  std::vector<std::pair<SymbolicPredicate, ir::Expr>> GetIRs() override;

 private:
  // Candidate schedules produced by Schedule(), keyed by their guarding
  // predicate.
  std::vector<std::pair<SymbolicPredicate, std::unique_ptr<ir::IRSchedule>>>
      ir_schs_;
};
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
namespace ir {
// Names of external call functions whose schedule blocks must not be
// scheduled. The set is macro-generated: for each function stem listed in
// GEN_FUNC_NAME, one entry per type suffix is produced, e.g.
// "cinn_nvgpu_gt_num_fp32".
static const std::unordered_set<std::string>
    kProhibitScheduleExternalFuncNames = {
#define CINN_NVGPU_FUNC2STRING(str) #str
#define CINN_NVGPU_FUNC_TYPE(FUNC, TYPE) \
  CINN_NVGPU_FUNC2STRING(cinn_nvgpu_##FUNC##TYPE)
#define GEN_FUNC_NAME(_, impl) \
  _(impl, gt_num) \
  _(impl, lt_num) \
  _(impl, index_add) \
  _(impl, next_smallest)
#define GEN_FUNC_NAME_WITH_TYPE(_, ...) \
  _(__VA_ARGS__, _bool), _(__VA_ARGS__, _fp16), _(__VA_ARGS__, _fp32), \
      _(__VA_ARGS__, _fp64), _(__VA_ARGS__, _uint8), _(__VA_ARGS__, _int8), \
      _(__VA_ARGS__, _int16), _(__VA_ARGS__, _int32), _(__VA_ARGS__, _int64),
        GEN_FUNC_NAME(GEN_FUNC_NAME_WITH_TYPE, CINN_NVGPU_FUNC_TYPE)
// Fix: undefine ALL helper macros; the original left three of them leaking
// into the rest of the translation unit.
#undef GEN_FUNC_NAME
#undef GEN_FUNC_NAME_WITH_TYPE
#undef CINN_NVGPU_FUNC_TYPE
#undef CINN_NVGPU_FUNC2STRING
};
// Returns true iff the given ScheduleBlockRealize contains a call to an
// external function listed in kProhibitScheduleExternalFuncNames (such
// blocks must be left unscheduled).
bool IsProhibitScheduleExternCallBlock(ir::Expr block) {
  ir::ScheduleBlockRealize* sch_block_realize =
      block.As<ir::ScheduleBlockRealize>();
  CHECK_NOTNULL(sch_block_realize);
  ir::ScheduleBlock* sch_block =
      sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(sch_block);
  // Collect every Call node in the block body.
  auto find_call = ir::ir_utils::CollectIRNodesWithoutTensor(
      sch_block->body, [&](const Expr* x) { return x->As<ir::Call>(); });
  for (ir::Expr call : find_call) {
    // Fix: reuse the cast result; the original computed `call_node` but then
    // repeated call.As<ir::Call>() twice and never used it.
    ir::Call* call_node = call.As<ir::Call>();
    if (call_node != nullptr &&
        kProhibitScheduleExternalFuncNames.count(call_node->name) != 0) {
      return true;
    }
  }
  return false;
}
// Find loops with same extents of 2 ScheduleBlock.
// Walks the two control-stmt lists in lockstep from the outermost loop and
// stops at the first position where they no longer match.
std::vector<std::tuple<ir::Expr, ir::Expr>> FindSameOuterLoops(
    ir::ScheduleBlockNode* source_node, ir::ScheduleBlockNode* target_node) {
  std::vector<ir::Expr> src_stmts = source_node->ControlStmts();
  std::vector<ir::Expr> tgt_stmts = target_node->ControlStmts();
  std::vector<std::tuple<ir::Expr, ir::Expr>> matched;
  int limit = std::min(src_stmts.size(), tgt_stmts.size());
  for (int idx = 0; idx < limit; ++idx) {
    const bool both_are_for =
        src_stmts[idx].As<ir::For>() && tgt_stmts[idx].As<ir::For>();
    if (!both_are_for ||
        ir::GetLoopExtent(src_stmts[idx]) != GetLoopExtent(tgt_stmts[idx])) {
      break;
    }
    matched.emplace_back(src_stmts[idx], tgt_stmts[idx]);
  }
  return matched;
}
// Collects the names of all loop variables that appear in the iter_values
// bound to reduce-axis iter vars of the given ScheduleBlockRealize.
std::unordered_set<std::string> GetReduceLoopVarNames(ir::Expr block) {
  ir::ScheduleBlockRealize* schedule_block_realize =
      block.As<ir::ScheduleBlockRealize>();
  // Fix: check the casts, consistent with IsProhibitScheduleExternCallBlock;
  // the original dereferenced them unconditionally.
  CHECK_NOTNULL(schedule_block_realize);
  ir::ScheduleBlock* schedule_block =
      schedule_block_realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(schedule_block);
  std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values;
  std::vector<ir::Var> iter_vars = schedule_block->iter_vars;
  std::unordered_set<std::string> reduce_loop_var_names;
  for (size_t i = 0; i < iter_vars.size(); ++i) {
    if (!iter_vars[i]->is_reduce_axis) {
      continue;
    }
    // The visitor always returns false (collects nothing); it is used purely
    // to walk the iter_value expression and record every var name seen.
    ir::ir_utils::CollectIRNodesWithoutTensor(
        iter_values[i], [&](const ir::Expr* x) {
          if (x->as_var()) {
            reduce_loop_var_names.insert(x->as_var_ref()->name);
          }
          return false;
        });
  }
  return reduce_loop_var_names;
}
// Gathers the names of the block's reduce-axis iter vars.
std::unordered_set<std::string> GetReduceVarNames(ir::Expr block) {
  ir::ScheduleBlockRealize* realize = block.As<ir::ScheduleBlockRealize>();
  ir::ScheduleBlock* sch_block =
      realize->schedule_block.As<ir::ScheduleBlock>();
  std::vector<ir::Var> iter_vars = sch_block->iter_vars;
  std::unordered_set<std::string> names;
  for (const ir::Var& iter_var : iter_vars) {
    if (iter_var->is_reduce_axis) {
      names.insert(iter_var->name);
    }
  }
  return names;
}
// Full static-shape scheduling pipeline. Pass order matters: alignment and
// inlining first, then (CUDA only) reduction optimization, then the two loop
// fusions, and finally CUDA axis binding and storage allocation.
void StaticShapeGroupScheduler::Schedule() {
  // Feasibility condition consulted before loop fusions (see MeetConditions
  // callers); currently only graph-dependency preservation.
  feasible_conditions_.emplace_back(
      &StaticShapeGroupScheduler::IsKeepGraphDependency);
  DoLoopAlignment();
  DoComputeInline();
#ifdef CINN_WITH_CUDA
  OptimizeReduction();
#endif
  DoHorizontalLoopFusion();
  DoVerticalLoopFusion();
#ifdef CINN_WITH_CUDA
  BindCudaAxis();
  AllocateStorage();
#endif
}
// Reduced pipeline for the MapExpr path: only compute-inline (and, on CUDA,
// storage allocation) is applied.
void StaticShapeGroupScheduler::MapExprSchedule() {
  DoComputeInline();
#ifdef CINN_WITH_CUDA
  AllocateStorage();
#endif
}
// Static shape yields a single scheduled IR, guarded by the trivially-true
// predicate Expr(1).
std::vector<std::pair<SymbolicPredicate, ir::Expr>>
StaticShapeGroupScheduler::GetIRs() {
  ir::Expr func_body = ir_sch_->GetModule().GetExprs()[0];
  return {{Expr(1), func_body}};
}
// Priority of a node is the product of its loop extents; reduce loops are
// additionally weighted by log2 so reduction-heavy nodes win. A node whose
// loops are already bound to GPU axes reports has_loop_binded = true.
NodePriority StaticShapeGroupScheduler::CalculateNodePriority(
    const ir::ScheduleBlockNode* node) const {
  bool has_loop_binded = false;
  std::unordered_set<std::string> reduce_loop_var_names =
      GetReduceLoopVarNames(node->Block());
  int64_t reduce_score = 1;
  double score = 1;
  for (Expr expr : node->ControlStmts()) {
    ir::For* for_node = expr.As<ir::For>();
    // Fix: skip non-For stmts entirely. The original null-checked for_node
    // only around the extent product and then dereferenced it
    // unconditionally (loop_var, is_binded) — a null-pointer dereference for
    // any non-For control stmt.
    if (for_node == nullptr) {
      continue;
    }
    score *= ir::GetLoopExtent(expr);
    if (reduce_loop_var_names.count(for_node->loop_var->name) != 0) {
      reduce_score *= ir::GetLoopExtent(expr);
    }
    if (for_node->is_binded()) {
      has_loop_binded = true;
    }
  }
  if (reduce_score > 1) {
    score *= (reduce_score * std::log2(reduce_score));
  }
  VLOG(6) << "The priority score of node " << node->id() << " is " << score;
  VLOG(6) << "The node has_loop_binded: " << has_loop_binded;
  return NodePriority{has_loop_binded, score};
}
// Walks every node of the block graph and returns the one with the highest
// priority (per CalculateNodePriority); CHECK-fails if the graph is empty.
ir::ScheduleBlockNode* StaticShapeGroupScheduler::FindGlobalMasterNode() const {
  NodePriority best{false, std::numeric_limits<int64_t>::min()};
  ir::ScheduleBlockNode* best_node = nullptr;
  auto UpdateBest = [&](ir::ScheduleBlockNode* candidate) {
    NodePriority priority = CalculateNodePriority(candidate);
    VLOG(6) << "The priority score of node " << candidate->id() << " is "
            << priority.score
            << ", has_loop_binded: " << priority.has_loop_binded;
    if (best < priority) {
      best = priority;
      best_node = candidate;
    }
  };
  schedule_block_graph_->NodesWalk(UpdateBest);
  CHECK(best_node) << "Cannot find global master node";
  VLOG(6) << "Find the global master node: " << best_node->id();
  return best_node;
}
std::unordered_set<std::string> StaticShapeGroupScheduler::OutputTensorNames()
const {
std::unordered_set<std::string> output_tensor_names{output_tensor_names_};
for (ir::ScheduleBlockNode* node : schedule_block_graph_->EndPoints()) {
output_tensor_names.insert(node->id());
}
return output_tensor_names;
}
// Aligns the loop nest of every block with the loop nest of the global
// master block: fuse the block's loops into one, split it according to the
// master's loop extents, and reorder to recover the master's loop order.
void StaticShapeGroupScheduler::DoLoopAlignment() {
  VLOG(5) << "[Start LoopAlignment] func body: "
          << ir_sch_->GetModule().GetExprs().front();
  ir::ScheduleBlockNode* global_master = FindGlobalMasterNode();
  ir::Expr master_block = global_master->Block();
  std::vector<int> original_master_loop_extents;
  std::vector<int> spacial_master_loop_extents;
  std::vector<int> original_master_loop_order;
  std::vector<int> recover_loop_order;
  std::vector<ir::Expr> master_iter_values =
      master_block.As<ir::ScheduleBlockRealize>()->iter_values;
  std::vector<ir::Var> master_iter_vars =
      master_block.As<ir::ScheduleBlockRealize>()
          ->schedule_block.As<ir::ScheduleBlock>()
          ->iter_vars;
  std::vector<ir::Expr> master_loops = ir_sch_->GetLoops(master_block);

  std::unordered_set<std::string> reduce_var_names =
      GetReduceVarNames(master_block);
  if (!reduce_var_names.empty()) {
    // Reduction master: recover loop order and extents from the indices of
    // the unique Load that consumes reduce vars.
    std::set<ir::Expr> reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
        master_block,
        [&](const ir::Expr* x) {
          bool find_reduce_var = false;
          if (x->As<ir::Load>()) {
            // Fix: dropped the unused counter `i` the original incremented
            // here without ever reading it.
            for (ir::Expr index : x->As<ir::Load>()->indices) {
              if (index.as_var() &&
                  reduce_var_names.count(index.as_var_ref()->name) > 0) {
                find_reduce_var = true;
              }
            }
          }
          return find_reduce_var;
        },
        /* uniq_target = */ true);
    CHECK_EQ(reduce_loads.size(), 1);

    std::vector<ir::Expr> indices =
        reduce_loads.begin()->As<ir::Load>()->indices;
    for (ir::Expr index : indices) {
      CHECK_NOTNULL(index.as_var());
      // Locate the iter var matching this index; idx ends up as its position
      // in master_iter_vars, is_reduce_var tells whether it is a reduce axis.
      int idx = 0;
      bool is_reduce_var = false;
      for (const ir::Var& iter_var : master_iter_vars) {
        if (iter_var->name == index.as_var_ref()->name) {
          is_reduce_var = iter_var->is_reduce_axis;
          break;
        }
        ++idx;
      }
      // Collect, in textual order, the loop vars appearing in the matching
      // iter value, then map them back to positions in master_loops.
      std::vector<ir::Var> loop_vars_in_order;
      ir::ir_utils::CollectIRNodesInOrder(
          master_iter_values[idx], [&](const ir::Expr* x) {
            if (x->as_var()) {
              loop_vars_in_order.push_back(x->as_var_ref());
            }
            return false;
          });
      for (const ir::Var& loop_var : loop_vars_in_order) {
        for (int i = 0; i < master_loops.size(); ++i) {
          if (master_loops[i].As<ir::For>()->loop_var->name == loop_var->name) {
            original_master_loop_order.push_back(i);
            int extent = ir::GetLoopExtent(master_loops[i]);
            original_master_loop_extents.push_back(extent);
            if (!is_reduce_var) {
              spacial_master_loop_extents.push_back(extent);
            }
          }
        }
      }
    }
    // recover_loop_order is the permutation that restores the master order:
    // recover_loop_order[i] = position j with original_master_loop_order[j] == i.
    for (int i = 0; i < original_master_loop_order.size(); ++i) {
      for (int j = 0; j < original_master_loop_order.size(); ++j) {
        if (original_master_loop_order[j] == i) {
          recover_loop_order.push_back(j);
          break;
        }
      }
    }
    CHECK_EQ(original_master_loop_order.size(), recover_loop_order.size());
  } else {
    // No reduction: the master loops are already in canonical order.
    for (int i = 0; i < master_loops.size(); ++i) {
      original_master_loop_extents.push_back(
          ir::GetLoopExtent(master_loops[i]));
      spacial_master_loop_extents.push_back(ir::GetLoopExtent(master_loops[i]));
      original_master_loop_order.push_back(i);
      recover_loop_order.push_back(i);
    }
  }

  int total_master_loop_extents = 1;
  int total_spacial_loop_extents = 1;
  for (int extent : original_master_loop_extents) {
    total_master_loop_extents *= extent;
  }
  for (int extent : spacial_master_loop_extents) {
    total_spacial_loop_extents *= extent;
  }

  auto LoopAlignmentFunc = [&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return false;
    }
    if (node == global_master) {
      return false;
    }
    for (ir::Expr expr : node->ControlStmts()) {
      ir::For* for_node = expr.As<ir::For>();
      // Fix: skip non-For stmts. The original null-checked the cast only in
      // the first condition and then dereferenced expr.As<ir::For>()
      // unconditionally in the second one.
      if (for_node == nullptr) {
        continue;
      }
      // Blocks already bound to GPU axes must not be realigned.
      if (for_node->for_type() == ir::ForType::GPUBlock ||
          for_node->for_type() == ir::ForType::GPUThread) {
        return false;
      }
      // Skip imperfect loop nests (a loop body holding several stmts).
      if (for_node->body.As<ir::Block>() &&
          for_node->body.As<ir::Block>()->stmts.size() > 1) {
        return false;
      }
    }

    VLOG(6) << "try to align loops of block: " << node->id()
            << " with block: " << global_master->id();

    // 1. Fuse source loops
    ir::Expr source_loop = ir_sch_->Fuse(node->ControlStmts());
    int total_source_extent = ir::GetLoopExtent(source_loop);

    // 2. Split source loop to align with the target loops
    std::vector<int> target_loop_extents;
    if (total_source_extent < total_spacial_loop_extents) {
      // Cover only a prefix of the spatial extents; the last factor becomes
      // -1 (inferred) when the prefix does not divide evenly.
      int cur_extent = 1;
      for (int extent : spacial_master_loop_extents) {
        cur_extent *= extent;
        if (cur_extent == total_source_extent) {
          target_loop_extents.push_back(extent);
          break;
        } else if (cur_extent > total_source_extent) {
          target_loop_extents.push_back(-1);
          break;
        } else {
          target_loop_extents.push_back(extent);
        }
      }
    } else if (total_source_extent == total_spacial_loop_extents) {
      target_loop_extents = spacial_master_loop_extents;
    } else if (total_source_extent < total_master_loop_extents) {
      target_loop_extents = spacial_master_loop_extents;
      target_loop_extents.push_back(-1);
    } else if (total_source_extent == total_master_loop_extents) {
      target_loop_extents = original_master_loop_extents;
    }
    std::vector<ir::Expr> source_loops;
    if (target_loop_extents.size() > 0 &&
        target_loop_extents[0] < total_source_extent) {
      source_loops = ir_sch_->Split(source_loop, target_loop_extents);
    } else {
      source_loops = {source_loop};
    }

    // 3. Reorder loops to match the target loops
    if (total_source_extent == total_master_loop_extents) {
      ir_sch_->Reorder(node->id(), recover_loop_order);
    }
    return true;
  };

  schedule_block_graph_->DFSTopoWalk(LoopAlignmentFunc);
  VLOG(5) << "[After LoopAlignment] func body: "
          << ir_sch_->GetModule().GetExprs().front();
}
// Applies the AutoInline rule to every schedule block in DFS topological
// order, skipping blocks that call prohibited extern functions, then
// rebuilds the block graph to reflect the removed blocks.
void StaticShapeGroupScheduler::DoComputeInline() {
  VLOG(5) << "[Start DoComputeInline] func body: "
          << ir_sch_->GetModule().GetExprs().front();
  // Tensors that must not be inlined: user outputs plus graph end points.
  std::unordered_set<std::string> no_inline_output_names = OutputTensorNames();
  auto_schedule::AutoInline inliner(target_, no_inline_output_names);
  auto InlineFunc = [&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return;
    }
    VLOG(6) << "try ComputeInline on: " << node->id()
            << ", before ComputeInline, func body: "
            << ir_sch_->GetModule().GetExprs().front();
    ir::Expr schedule_block = node->Block();
    inliner.Apply(ir_sch_, schedule_block);
    VLOG(6) << "try ComputeInline on: " << node->id()
            << ", after ComputeInline, func body: "
            << ir_sch_->GetModule().GetExprs().front();
  };
  schedule_block_graph_->DFSTopoWalk(InlineFunc);
  // Inlining removed blocks; refresh the graph so it matches the new IR.
  schedule_block_graph_->Update(*ir_sch_);
  VLOG(5) << "[After DoComputeInline] func body: "
          << ir_sch_->GetModule().GetExprs().front();
}
void StaticShapeGroupScheduler::DoHorizontalLoopFusion() {
VLOG(5) << "[Start DoHorizontalLoopFusion] func body: "
<< ir_sch_->GetModule().GetExprs().front();
std::vector<ir::ScheduleBlockNode*> end_nodes =
schedule_block_graph_->EndPoints();
std::reverse(end_nodes.begin(), end_nodes.end());
ir::ScheduleBlockNode* master_node = end_nodes.front();
CHECK_NOTNULL(master_node);
for (int i = 1; i < end_nodes.size(); ++i) {
if (IsProhibitScheduleExternCallBlock(end_nodes[i]->Block())) {
continue;
}
VLOG(6) << "try to fuse loop of " << end_nodes[i]->id() << " to "
<< master_node->id();
std::vector<std::tuple<cinn::ir::Expr, cinn::ir::Expr>>&& same_loops =
FindSameOuterLoops(end_nodes[i], master_node);
if (same_loops.size() == 0) {
continue;
}
ir::Expr target_loop = std::get<1>(same_loops.back());
VLOG(6) << "target_loop: " << target_loop;
ir_sch_->SimpleComputeAt(end_nodes[i]->Block(), target_loop);
VLOG(6) << "after fuse: " << ir_sch_->GetModule().GetExprs().front();
}
VLOG(5) << "[After DoHorizontalLoopFusion] func body: "
<< ir_sch_->GetModule().GetExprs().front();
}
// Fuses each block vertically into one of its consumers ("masters"): picks
// the highest-priority consumer whose shared outer loop passes the
// feasibility conditions, ComputeAt-s the block there, and re-applies the
// block's original GPU axis bindings afterwards.
void StaticShapeGroupScheduler::DoVerticalLoopFusion() {
  VLOG(5) << "[Start DoVerticalLoopFusion] func body: "
          << ir_sch_->GetModule().GetExprs().front();
  UpdateBlockOrder();
  // Masters of a node are its consumers, sorted by descending priority so
  // the most expensive consumer is tried first.
  auto FindMaster =
      [&](ir::ScheduleBlockNode* node) -> std::vector<ir::ScheduleBlockNode*> {
    std::vector<ir::ScheduleBlockNode*> masters = node->Consumers();
    std::sort(
        masters.begin(),
        masters.end(),
        [&](const ir::ScheduleBlockNode* a, const ir::ScheduleBlockNode* b) {
          // Descending order: compare b against a.
          return this->CalculateNodePriority(b) <
                 this->CalculateNodePriority(a);
        });
    return masters;
  };
  auto ComputeAtFunc = [&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return;
    }
    std::vector<ir::ScheduleBlockNode*> masters = FindMaster(node);
    if (masters.size() == 0) {
      return;
    }
    ir::Expr target_loop;
    bool find_target_loop = false;
    // Collect information of original loops
    std::vector<ir::Expr> original_ctrl_stmts = node->ControlStmts();
    int64_t original_total_loop_extent = 1;
    // Per original loop: (thread_axis, extent). thread_axis is "" for serial
    // loops, or e.g. "blockIdx.x" / "threadIdx.y" for GPU-bound loops.
    std::vector<std::pair<std::string, int>> original_loop_infos;
    std::unordered_set<ir::IrNode*> original_loop_node_ptrs;
    for (ir::Expr stmt : original_ctrl_stmts) {
      if (stmt.As<ir::For>()) {
        int extent = ir::GetLoopExtent(stmt);
        original_total_loop_extent *= extent;
        std::string thread_axis = "";
        ir::ForType target_for_type = stmt.As<ir::For>()->for_type();
        if (target_for_type == ir::ForType::GPUBlock) {
          thread_axis += "blockIdx.";
        } else if (target_for_type == ir::ForType::GPUThread) {
          thread_axis += "threadIdx.";
        } else {
          // Serial loop: record an empty axis so indices stay aligned.
          original_loop_infos.push_back(std::make_pair(thread_axis, extent));
          continue;
        }
        // bind_info().offset selects the axis letter: 0 -> x, 1 -> y, 2 -> z.
        int offset = stmt.As<ir::For>()->bind_info().offset;
        thread_axis += ('x' + offset);
        original_loop_infos.push_back(std::make_pair(thread_axis, extent));
        original_loop_node_ptrs.insert(stmt.ptr());
      }
    }
    std::unordered_set<std::string> src_reduce_loop_var_names =
        GetReduceLoopVarNames(node->Block());
    for (ir::ScheduleBlockNode* master : masters) {
      // Find the target loop candidates;
      // a candidate is a shared outer loop that is not a reduce loop on
      // either side and is not literally the same loop node.
      std::vector<ir::Expr> target_loop_candidates;
      int64_t total_loop_extent = 1;
      std::unordered_set<std::string> tgt_reduce_loop_var_names =
          GetReduceLoopVarNames(master->Block());
      std::vector<std::tuple<cinn::ir::Expr, cinn::ir::Expr>> same_loops =
          FindSameOuterLoops(node, master);
      for (const std::tuple<cinn::ir::Expr, cinn::ir::Expr>& same_loop :
           same_loops) {
        ir::Expr source_loop = std::get<0>(same_loop);
        // NOTE(review): this local shadows the outer `target_loop` declared
        // above; only the outer one is assigned when a candidate is chosen.
        ir::Expr target_loop = std::get<1>(same_loop);
        bool is_src_loop_reduce =
            src_reduce_loop_var_names.count(
                source_loop.As<ir::For>()->loop_var->name) > 0;
        bool is_tgt_loop_reduce =
            tgt_reduce_loop_var_names.count(
                target_loop.As<ir::For>()->loop_var->name) > 0;
        if (source_loop.ptr() != target_loop.ptr() && !is_src_loop_reduce &&
            !is_tgt_loop_reduce) {
          target_loop_candidates.push_back(target_loop);
        }
      }
      // Find the target loop with the highest priority and passing the
      // feasibility condition check (deepest shared loop first).
      for (std::vector<ir::Expr>::reverse_iterator iter =
               target_loop_candidates.rbegin();
           iter != target_loop_candidates.rend();
           ++iter) {
        ir::Expr candidate_loop = *iter;
        if (candidate_loop.As<ir::For>() &&
            this->MeetConditions(node->Block(), candidate_loop, 0)) {
          target_loop = candidate_loop;
          find_target_loop = true;
          break;
        }
      }
      if (find_target_loop) {
        VLOG(6) << "try to fuse loop of " << node->id() << " to "
                << master->id();
        break;
      }
    }
    // Do schedule
    if (find_target_loop) {
      ir_sch_->SimpleComputeAt(node->Block(), target_loop);
      VLOG(6) << "after compute at: " << ir_sch_->GetModule().GetExprs()[0];
      // Re-apply the original GPU bindings to the post-fusion loops.
      std::vector<ir::Expr> new_stmts = node->ControlStmts();
      for (int idx = 0; idx < original_loop_infos.size(); ++idx) {
        if (original_loop_infos[idx].first.empty()) {
          continue;
        }
        if (idx < new_stmts.size()) {
          CHECK(new_stmts[idx].As<ir::For>());
          if (new_stmts[idx].As<ir::For>()->is_serial()) {
            ir_sch_->Bind(new_stmts[idx], original_loop_infos[idx].first);
          }
        } else {
          // Fewer loops than before: append a unit loop to carry the binding.
          ir::Expr unit_loop = ir_sch_->AddUnitLoop(node->Block());
          ir_sch_->Bind(unit_loop, original_loop_infos[idx].first);
        }
      }
      VLOG(6) << "after loop info copy: " << ir_sch_->GetModule().GetExprs()[0];
      // Update block and control stmts order after schedule.
      this->UpdateBlockOrder();
    } else {
      LOG(INFO) << "Cannot find a loop of masters to ComputeAt, do not merge.\n"
                << "The schedule block: " << node->Block();
    }
  };
  schedule_block_graph_->DFSTopoWalk(ComputeAtFunc);
  VLOG(5) << "[After DoVerticalLoopFusion] func body: "
          << ir_sch_->GetModule().GetExprs().front();
}
// Binds loops of every schedule block to CUDA axes via the auto-schedule
// AutoBind rule; no-op on non-NVGPU targets.
void StaticShapeGroupScheduler::BindCudaAxis() {
  if (target_.arch != Target::Arch::NVGPU) return;
  VLOG(5) << "[Start BindCudaAxis] func body: "
          << ir_sch_->GetModule().GetExprs().front();
  auto_schedule::AutoBind binder(target_);
  auto ApplyBind = [&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return;
    }
    VLOG(6) << "try bind cuda axis on: " << node->id()
            << ", before bind, func body: "
            << ir_sch_->GetModule().GetExprs().front();
    binder.Apply(ir_sch_, node->id());
    VLOG(6) << "try bind cuda axis on: " << node->id()
            << ", after bind, func body: "
            << ir_sch_->GetModule().GetExprs().front();
  };
  schedule_block_graph_->DFSTopoWalk(ApplyBind);
  VLOG(5) << "[After BindCudaAxis] func body: "
          << ir_sch_->GetModule().GetExprs().front();
}
// A closed integer interval [min, max] describing an index range.
// Fix: zero-initialize the members. Ranges are default-constructed elsewhere
// in this file (e.g. `std::vector<std::pair<int, Range>> coef_and_ranges(3)`)
// and their min/max are read (logged) before being assigned on some paths,
// which was a read of indeterminate values.
struct Range {
  int min = 0;
  int max = 0;
};

// Prints a Range as "(min, max)".
std::ostream& operator<<(std::ostream& os, const Range& x) {
  os << "(" << x.min << ", " << x.max << ")";
  return os;
}
// TODO(BiynXu): After implementing auxiliary data structures such as IntegerSet
// and MultiDimIntegerSet, re implement this function to simplify these ugly
// codes.
void StaticShapeGroupScheduler::AllocateStorage() {
if (target_.arch != Target::Arch::NVGPU) return;
VLOG(5) << "[Start AllocateStorage] func body: "
<< ir_sch_->GetModule().GetExprs().front();
// Record ir::For using index structure: <block_name, <var_name, for_node>>
std::unordered_map<std::string, std::unordered_map<std::string, ir::Expr>>
for_map;
std::unordered_set<std::string> sync_mark;
// function to update for_map
auto UpdateVarNameToForMap = [&](ir::Expr root) {
std::vector<ir::Expr> all_blocks = ir_sch_->GetAllBlocks();
for (const ir::Expr& block : all_blocks) {
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
std::vector<ir::Expr> for_expr = ir_sch_->GetLoops(block);
for (ir::Expr for_expr : for_expr) {
for_map[block_name][for_expr.As<ir::For>()->loop_var->name] = for_expr;
VLOG(6) << "for_map.insert: <" << block_name << ", "
<< for_expr.As<ir::For>()->loop_var->name << ">";
}
}
};
// function to analyze and flatten indices to one dim of load_or_store node
auto AnalyzeIndiceValue = [](ir::Expr load_or_store,
ir::Expr block) -> ir::Expr {
std::vector<ir::Expr> indices;
ir::Tensor tensor;
if (load_or_store.As<ir::Load>()) {
indices = load_or_store.As<ir::Load>()->indices;
tensor = load_or_store.As<ir::Load>()->tensor.as_tensor_ref();
} else {
indices = load_or_store.As<ir::Store>()->indices;
tensor = load_or_store.As<ir::Store>()->tensor.as_tensor_ref();
}
std::vector<ir::Var> iter_vars =
block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
std::vector<ir::Expr> iter_values =
block.As<ir::ScheduleBlockRealize>()->iter_values;
struct VarHash {
size_t operator()(const ir::Var& var) const {
std::string name = var->name;
return std::hash<std::string>()(name);
}
};
std::vector<int> strides;
int extent = 1;
for (int idx = tensor->shape.size() - 1; idx >= 0; --idx) {
strides.insert(strides.begin(), extent);
tensor->shape[idx] = common::AutoSimplify(tensor->shape[idx]);
CHECK(tensor->shape[idx].is_constant())
<< "Shape of tensor: " << tensor << " is not constant";
extent *= tensor->shape[idx].get_constant();
}
ir::Expr flatten_indice(0);
for (int idx = 0; idx < indices.size(); ++idx) {
flatten_indice = flatten_indice + ir::Expr(strides[idx]) * indices[idx];
}
flatten_indice = common::AutoSimplify(flatten_indice);
for (int idx = 0; idx < iter_vars.size(); ++idx) {
optim::ReplaceVarWithExpr(
&flatten_indice, iter_vars[idx], iter_values[idx]);
}
flatten_indice = common::AutoSimplify(flatten_indice);
VLOG(6) << "flatten_indice of " << load_or_store << " : " << flatten_indice;
return flatten_indice;
};
enum class CudaBindInfo : int {
kCudaBlock,
kCudaThread,
kSerial,
kCudaThreadAndSerial,
};
// function to calculate the range of the specified CUDA axis in a indice
// expression
auto CalculateRange = [&for_map](ir::Expr indice_value,
const CudaBindInfo& bind_info,
const std::string& block_name) {
ir::Expr copy_for_upper_bound = ir::ir_utils::IRCopy(indice_value);
ir::Expr copy_for_lower_bound = ir::ir_utils::IRCopy(indice_value);
std::set<ir::Expr> var_set = ir::ir_utils::CollectIRNodesWithoutTensor(
indice_value, [](const ir::Expr* x) { return x->as_var(); });
for (ir::Expr var : var_set) {
std::string name = var.as_var_ref()->name;
CHECK(for_map.find(block_name) != for_map.end());
CHECK(for_map[block_name].find(name) != for_map[block_name].end());
ir::Expr for_expr = for_map[block_name][name];
if (bind_info == CudaBindInfo::kCudaBlock) {
if (for_expr.As<ir::For>()->is_gpu_block_binded()) {
optim::ReplaceVarWithExpr(&copy_for_upper_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min +
for_expr.As<ir::For>()->extent -
Expr(1));
optim::ReplaceVarWithExpr(&copy_for_lower_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min);
} else {
optim::ReplaceVarWithExpr(
&copy_for_upper_bound, var.as_var_ref(), ir::Expr(0));
optim::ReplaceVarWithExpr(
&copy_for_lower_bound, var.as_var_ref(), ir::Expr(0));
}
} else if (bind_info == CudaBindInfo::kCudaThread) {
if (for_expr.As<ir::For>()->is_gpu_thread_binded()) {
optim::ReplaceVarWithExpr(&copy_for_upper_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min +
for_expr.As<ir::For>()->extent -
Expr(1));
optim::ReplaceVarWithExpr(&copy_for_lower_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min);
} else {
optim::ReplaceVarWithExpr(
&copy_for_upper_bound, var.as_var_ref(), ir::Expr(0));
optim::ReplaceVarWithExpr(
&copy_for_lower_bound, var.as_var_ref(), ir::Expr(0));
}
} else if (bind_info == CudaBindInfo::kSerial) {
if (!for_expr.As<ir::For>()->is_gpu_thread_binded() &&
!for_expr.As<ir::For>()->is_gpu_block_binded()) {
optim::ReplaceVarWithExpr(&copy_for_upper_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min +
for_expr.As<ir::For>()->extent -
Expr(1));
optim::ReplaceVarWithExpr(&copy_for_lower_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min);
} else {
optim::ReplaceVarWithExpr(
&copy_for_upper_bound, var.as_var_ref(), ir::Expr(0));
optim::ReplaceVarWithExpr(
&copy_for_lower_bound, var.as_var_ref(), ir::Expr(0));
}
} else if (bind_info == CudaBindInfo::kCudaThreadAndSerial) {
if (!for_expr.As<ir::For>()->is_gpu_block_binded()) {
optim::ReplaceVarWithExpr(&copy_for_upper_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min +
for_expr.As<ir::For>()->extent -
Expr(1));
optim::ReplaceVarWithExpr(&copy_for_lower_bound,
var.as_var_ref(),
for_expr.As<ir::For>()->min);
} else {
optim::ReplaceVarWithExpr(
&copy_for_upper_bound, var.as_var_ref(), ir::Expr(0));
optim::ReplaceVarWithExpr(
&copy_for_lower_bound, var.as_var_ref(), ir::Expr(0));
}
}
}
VLOG(6) << "lower_bound before simplify of " << indice_value << " = "
<< copy_for_lower_bound;
copy_for_lower_bound =
common::AutoSimplify(common::AutoSimplify(copy_for_lower_bound));
VLOG(6) << "upper_bound before simplify of " << indice_value << " = "
<< copy_for_upper_bound;
copy_for_upper_bound =
common::AutoSimplify(common::AutoSimplify(copy_for_upper_bound));
VLOG(6) << "lower_bound of " << indice_value << " = "
<< copy_for_lower_bound;
VLOG(6) << "upper_bound of " << indice_value << " = "
<< copy_for_upper_bound;
return Range{static_cast<int>(copy_for_lower_bound.get_constant()),
static_cast<int>(copy_for_upper_bound.get_constant())};
};
// function to calculate the coefficient and range of the specified for_type
// in a indice expression
auto GetCoefficientAndRange = [&for_map](ir::Expr indice_value,
const ir::ForType& for_type,
const std::string& block_name) {
std::vector<std::pair<int, Range>> coef_and_ranges(3);
std::vector<ir::Expr> indice_copies;
for (int i = 0; i < 3; ++i) {
indice_copies.push_back(ir::ir_utils::IRCopy(indice_value));
}
std::set<ir::Expr> var_set = ir::ir_utils::CollectIRNodesWithoutTensor(
indice_value, [](const ir::Expr* x) { return x->as_var(); });
std::unordered_set<std::string> visited_var_names;
for (ir::Expr var : var_set) {
std::string name = var.as_var_ref()->name;
if (visited_var_names.count(name) > 0) {
continue;
}
visited_var_names.insert(name);
CHECK(for_map.find(block_name) != for_map.end());
CHECK(for_map[block_name].find(name) != for_map[block_name].end());
ir::Expr for_expr = for_map[block_name][name];
for (int i = 0; i < 3; ++i) {
if (for_type == for_expr.As<ir::For>()->for_type() &&
for_expr.As<ir::For>()->bind_info().offset == i &&
for_expr.As<ir::For>()->extent.get_constant() > 1) {
optim::ReplaceVarWithExpr(
&(indice_copies[i]), var.as_var_ref(), ir::Expr(1));
coef_and_ranges[i].second.min =
for_expr.As<ir::For>()->min.get_constant();
coef_and_ranges[i].second.max =
for_expr.As<ir::For>()->min.get_constant() +
for_expr.As<ir::For>()->extent.get_constant();
} else {
optim::ReplaceVarWithExpr(
&(indice_copies[i]), var.as_var_ref(), ir::Expr(0));
}
}
}
for (int i = 0; i < 3; ++i) {
VLOG(6) << "before simplify [" << i << "], the coefficient of "
<< indice_value << " = " << indice_copies[i] << ", range = ("
<< coef_and_ranges[i].second.min << ", "
<< coef_and_ranges[i].second.max << ")";
indice_copies[i] = common::AutoSimplify(indice_copies[i]);
VLOG(6) << "after simplify [" << i << "], the coefficient of "
<< indice_value << " = " << indice_copies << ", range = ("
<< coef_and_ranges[i].second.min << ", "
<< coef_and_ranges[i].second.max << ")";
coef_and_ranges[i].first =
static_cast<int>(indice_copies[i].get_constant());
}
return coef_and_ranges;
};
  // Determine whether the indice of a pair of Store and Load cross CUDA threads
  auto IsCrossThread = [&](ir::Expr store_indice_value,
                           ir::Expr load_indice_value,
                           const std::string& store_block_name,
                           const std::string& load_block_name) {
    // Overall ranges the store/load indices sweep over cuda-thread-bound
    // loops and over serial (unbound) loops.
    Range store_thread_overall_range = CalculateRange(
        store_indice_value, CudaBindInfo::kCudaThread, store_block_name);
    Range load_thread_overall_range = CalculateRange(
        load_indice_value, CudaBindInfo::kCudaThread, load_block_name);
    Range store_serial_overall_range = CalculateRange(
        store_indice_value, CudaBindInfo::kSerial, store_block_name);
    Range load_serial_overall_range = CalculateRange(
        load_indice_value, CudaBindInfo::kSerial, load_block_name);
    // Per-axis (offset 0/1/2, i.e. threadIdx.x/y/z) coefficient and range of
    // the indices on gpu-thread loops.
    auto store_thread_coefficient_and_range = GetCoefficientAndRange(
        store_indice_value, ir::ForType::GPUThread, store_block_name);
    auto load_thread_coefficient_and_range = GetCoefficientAndRange(
        load_indice_value, ir::ForType::GPUThread, load_block_name);
    VLOG(6) << "store_block_name: " << store_block_name
            << ", load_block_name: " << load_block_name;
    VLOG(6) << "store_indice_value: " << store_indice_value
            << ", load_indice_value: " << load_indice_value;
    VLOG(6) << "store_thread_overall_range = " << store_thread_overall_range;
    VLOG(6) << "load_thread_overall_range = " << load_thread_overall_range;
    VLOG(6) << "store_serial_overall_range = " << store_serial_overall_range;
    VLOG(6) << "load_serial_overall_range = " << load_serial_overall_range;
    VLOG(6) << "store_thread_coefficient_and_range[0] = <"
            << store_thread_coefficient_and_range[0].first << ", "
            << store_thread_coefficient_and_range[0].second << ">";
    VLOG(6) << "load_thread_coefficient_and_range[0] = <"
            << load_thread_coefficient_and_range[0].first << ", "
            << load_thread_coefficient_and_range[0].second << ">";
    VLOG(6) << "store_thread_coefficient_and_range[1] = <"
            << store_thread_coefficient_and_range[1].first << ", "
            << store_thread_coefficient_and_range[1].second << ">";
    VLOG(6) << "load_thread_coefficient_and_range[1] = <"
            << load_thread_coefficient_and_range[1].first << ", "
            << load_thread_coefficient_and_range[1].second << ">";
    VLOG(6) << "store_thread_coefficient_and_range[2] = <"
            << store_thread_coefficient_and_range[2].first << ", "
            << store_thread_coefficient_and_range[2].second << ">";
    VLOG(6) << "load_thread_coefficient_and_range[2] = <"
            << load_thread_coefficient_and_range[2].first << ", "
            << load_thread_coefficient_and_range[2].second << ">";
    // The pair does NOT cross threads iff the store "covers" the load:
    // the store's thread and serial overall ranges contain the load's, and
    // on every thread axis the coefficients are equal (or the load does not
    // use that axis at all, coefficient 0) while the store's axis range
    // contains the load's. Any failed condition means a thread may read data
    // written by a different thread.
    return !(store_thread_overall_range.min <= load_thread_overall_range.min &&
             store_thread_overall_range.max >= load_thread_overall_range.max &&
             store_serial_overall_range.min <= load_serial_overall_range.min &&
             store_serial_overall_range.max >= load_serial_overall_range.max &&
             (store_thread_coefficient_and_range[0].first ==
                  load_thread_coefficient_and_range[0].first ||
              load_thread_coefficient_and_range[0].first == 0) &&
             store_thread_coefficient_and_range[0].second.min <=
                 load_thread_coefficient_and_range[0].second.min &&
             store_thread_coefficient_and_range[0].second.max >=
                 load_thread_coefficient_and_range[0].second.max &&
             (store_thread_coefficient_and_range[1].first ==
                  load_thread_coefficient_and_range[1].first ||
              load_thread_coefficient_and_range[1].first == 0) &&
             store_thread_coefficient_and_range[1].second.min <=
                 load_thread_coefficient_and_range[1].second.min &&
             store_thread_coefficient_and_range[1].second.max >=
                 load_thread_coefficient_and_range[1].second.max &&
             (store_thread_coefficient_and_range[2].first ==
                  load_thread_coefficient_and_range[2].first ||
              load_thread_coefficient_and_range[2].first == 0) &&
             store_thread_coefficient_and_range[2].second.min <=
                 load_thread_coefficient_and_range[2].second.min &&
             store_thread_coefficient_and_range[2].second.max >=
                 load_thread_coefficient_and_range[2].second.max);
  };
  // Determine whether the indice of a pair of Store and Load cross CUDA block
  auto IsCrossBlock = [&](ir::Expr store_indice_value,
                          ir::Expr load_indice_value,
                          const std::string& store_block_name,
                          const std::string& load_block_name) {
    // Overall ranges the store/load indices sweep over cuda-block-bound
    // loops, and over thread-bound plus serial loops combined.
    Range store_block_overall_range = CalculateRange(
        store_indice_value, CudaBindInfo::kCudaBlock, store_block_name);
    Range load_block_overall_range = CalculateRange(
        load_indice_value, CudaBindInfo::kCudaBlock, load_block_name);
    Range store_thread_and_serial_overall_range =
        CalculateRange(store_indice_value,
                       CudaBindInfo::kCudaThreadAndSerial,
                       store_block_name);
    Range load_thread_and_serial_overall_range = CalculateRange(
        load_indice_value, CudaBindInfo::kCudaThreadAndSerial, load_block_name);
    // Per-axis (offset 0/1/2, i.e. blockIdx.x/y/z) coefficient and range of
    // the indices on gpu-block loops.
    auto store_block_coefficient_and_range = GetCoefficientAndRange(
        store_indice_value, ir::ForType::GPUBlock, store_block_name);
    auto load_block_coefficient_and_range = GetCoefficientAndRange(
        load_indice_value, ir::ForType::GPUBlock, load_block_name);
    VLOG(6) << "store_block_name: " << store_block_name
            << ", load_block_name: " << load_block_name;
    VLOG(6) << "store_indice_value: " << store_indice_value
            << ", load_indice_value: " << load_indice_value;
    VLOG(6) << "store_block_overall_range = " << store_block_overall_range;
    VLOG(6) << "load_block_overall_range = " << load_block_overall_range;
    VLOG(6) << "store_thread_and_serial_overall_range = "
            << store_thread_and_serial_overall_range;
    VLOG(6) << "load_thread_and_serial_overall_range = "
            << load_thread_and_serial_overall_range;
    VLOG(6) << "store_block_coefficient_and_range[0] = <"
            << store_block_coefficient_and_range[0].first << ", "
            << store_block_coefficient_and_range[0].second << ">";
    VLOG(6) << "load_block_coefficient_and_range[0] = <"
            << load_block_coefficient_and_range[0].first << ", "
            << load_block_coefficient_and_range[0].second << ">";
    VLOG(6) << "store_block_coefficient_and_range[1] = <"
            << store_block_coefficient_and_range[1].first << ", "
            << store_block_coefficient_and_range[1].second << ">";
    VLOG(6) << "load_block_coefficient_and_range[1] = <"
            << load_block_coefficient_and_range[1].first << ", "
            << load_block_coefficient_and_range[1].second << ">";
    VLOG(6) << "store_block_coefficient_and_range[2] = <"
            << store_block_coefficient_and_range[2].first << ", "
            << store_block_coefficient_and_range[2].second << ">";
    VLOG(6) << "load_block_coefficient_and_range[2] = <"
            << load_block_coefficient_and_range[2].first << ", "
            << load_block_coefficient_and_range[2].second << ">";
    // Same covering test as IsCrossThread, but on cuda blocks: the pair does
    // NOT cross blocks iff the store's block range and thread+serial range
    // contain the load's, and on every block axis the coefficients are equal
    // (or the load does not use that axis, coefficient 0) while the store's
    // axis range contains the load's.
    return !(store_block_overall_range.min <= load_block_overall_range.min &&
             store_block_overall_range.max >= load_block_overall_range.max &&
             store_thread_and_serial_overall_range.min <=
                 load_thread_and_serial_overall_range.min &&
             store_thread_and_serial_overall_range.max >=
                 load_thread_and_serial_overall_range.max &&
             (store_block_coefficient_and_range[0].first ==
                  load_block_coefficient_and_range[0].first ||
              load_block_coefficient_and_range[0].first == 0) &&
             store_block_coefficient_and_range[0].second.min <=
                 load_block_coefficient_and_range[0].second.min &&
             store_block_coefficient_and_range[0].second.max >=
                 load_block_coefficient_and_range[0].second.max &&
             (store_block_coefficient_and_range[1].first ==
                  load_block_coefficient_and_range[1].first ||
              load_block_coefficient_and_range[1].first == 0) &&
             store_block_coefficient_and_range[1].second.min <=
                 load_block_coefficient_and_range[1].second.min &&
             store_block_coefficient_and_range[1].second.max >=
                 load_block_coefficient_and_range[1].second.max &&
             (store_block_coefficient_and_range[2].first ==
                  load_block_coefficient_and_range[2].first ||
              load_block_coefficient_and_range[2].first == 0) &&
             store_block_coefficient_and_range[2].second.min <=
                 load_block_coefficient_and_range[2].second.min &&
             store_block_coefficient_and_range[2].second.max >=
                 load_block_coefficient_and_range[2].second.max);
  };
  // function to set storage of each tensor
  auto SetStorage = [&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return;
    }
    // Default to GPU registers; upgraded to shared/global below when the
    // access pattern or the node's role requires it.
    ir::MemoryType memory_type = ir::MemoryType::GPULocal;
    ir::Expr cur_block = node->Block();
    ir::Expr root_block = ir_sch_->GetRootBlock(cur_block);
    UpdateVarNameToForMap(root_block);
    std::vector<ir::Expr> consumer_blocks =
        ir::GetConsumers(cur_block, root_block);
    // find store and corresponding load nodes
    ir::Expr find_store =
        *ir::ir_utils::CollectIRNodesWithoutTensor(
             cur_block,
             [&](const ir::Expr* x) { return x->As<ir::Store>(); },
             true)
             .begin();
    ir::Expr store_indice_value = AnalyzeIndiceValue(find_store, cur_block);
    std::vector<std::tuple<ir::Expr, ir::Expr>> loads_and_blocks;
    for (const ir::Expr& consumer_block : consumer_blocks) {
      // The collector's predicate always returns false; it is used here only
      // to walk the consumer block and record Loads of the stored tensor.
      ir::ir_utils::CollectIRNodesWithoutTensor(
          consumer_block, [&](const Expr* x) {
            if (x->As<ir::Load>() && (x->As<ir::Load>()->name() ==
                                      find_store.As<ir::Store>()->name())) {
              loads_and_blocks.push_back(std::make_tuple(*x, consumer_block));
            }
            return false;
          });
    }
    // Traverse load nodes to check if there are loads that cross cuda blocks or
    // threads
    for (const auto& load_and_block : loads_and_blocks) {
      ir::Expr load = std::get<0>(load_and_block);
      ir::Expr consumer_block = std::get<1>(load_and_block);
      std::string consumer_block_name =
          consumer_block.As<ir::ScheduleBlockRealize>()
              ->schedule_block.As<ir::ScheduleBlock>()
              ->name;
      ir::Expr load_indice_value = AnalyzeIndiceValue(load, consumer_block);
      if (IsCrossBlock(store_indice_value,
                       load_indice_value,
                       node->id(),
                       consumer_block_name)) {
        // TODO(BiynXu): Return error information to the front-end instead of
        // terminating the program.
        LOG(FATAL) << "Fusion requires synchronization across blocks, but "
                      "currently we do not support it.";
        break;
      } else if (IsCrossThread(store_indice_value,
                               load_indice_value,
                               node->id(),
                               consumer_block_name)) {
        // Cross-thread access within a block: shared memory is required.
        memory_type = ir::MemoryType::GPUShared;
      }
    }
    // Set output node to global
    std::unordered_set<std::string> output_names = OutputTensorNames();
    if (output_names.count(node->id()) > 0) {
      // MemoryType::Auto is handled as global memory in the branch below.
      memory_type = ir::MemoryType::Auto;
    }
    // Set the reduce_init tensor and the real tensor to the same memory
    if (ir::IsReduceInitTensorName(node->id())) {
      ir::Expr block =
          ir_sch_->GetBlock(ir::GetOriginalReduceTensorName(node->id()));
      memory_type = ir::GetTensor(block)->buffer->memory_type;
    }
    // Do schedule
    if (memory_type == ir::MemoryType::Auto) {
      VLOG(6) << "Set store tensor of block " << node->id() << " to global";
    } else if (memory_type == ir::MemoryType::GPUShared) {
      VLOG(6) << "Set store tensor of block " << node->id() << " to shared";
      ir_sch_->SetBuffer(cur_block, "shared");
      std::vector<ir::Expr> loops = ir_sch_->GetLoops(cur_block);
      // Insert __syncthreads only once per reduce tensor; sync_mark
      // de-duplicates the init tensor and the real tensor.
      if (sync_mark.count(ir::GetOriginalReduceTensorName(node->id())) == 0) {
        ir_sch_->SyncThreads(loops.back(), true);
        sync_mark.insert(ir::GetOriginalReduceTensorName(node->id()));
      }
    } else if (memory_type == ir::MemoryType::GPULocal) {
      VLOG(6) << "Set store tensor of block " << node->id() << " to register";
      ir_sch_->SetBuffer(cur_block, "local");
    }
  };
schedule_block_graph_->DFSTopoWalk(SetStorage);
VLOG(5) << "[After AllocateStorage] func body: "
<< ir_sch_->GetModule().GetExprs().front();
}
// Apply ReductionFactoring to every schedulable block of the group to
// optimize reductive calculations, then refresh the block graph so it
// reflects the rewritten IR.
void StaticShapeGroupScheduler::OptimizeReduction() {
  VLOG(5) << "[Start OptimizeReduction] func body: "
          << ir_sch_->GetModule().GetExprs().front();
  auto_schedule::ReductionFactoring rf(target_);
  schedule_block_graph_->DFSTopoWalk([&](ir::ScheduleBlockNode* node) {
    if (IsProhibitScheduleExternCallBlock(node->Block())) {
      return;
    }
    VLOG(6) << "try ReductionFactoring on: " << node->id()
            << ", before ReductionFactoring, func body: "
            << ir_sch_->GetModule().GetExprs().front();
    rf.Apply(node->id(), ir_sch_);
    VLOG(6) << "try ReductionFactoring on: " << node->id()
            << ", after ReductionFactoring, func body: "
            << ir_sch_->GetModule().GetExprs().front();
  });
  schedule_block_graph_->Update(*ir_sch_);
  VLOG(5) << "[After OptimizeReduction] func body: "
          << ir_sch_->GetModule().GetExprs().front();
}
void StaticShapeGroupScheduler::UpdateBlockOrder() {
ir::Expr root_block = ir_sch_->GetRootBlock(ir_sch_->GetAllBlocks()[0]);
ir::BlockOrderConstructor block_order_constructor;
blocks_order_with_ctrl_stmt_ = block_order_constructor(&root_block);
}
// Check that moving `schedule_block` to `insert_pos` inside `target_loop`
// keeps every upstream node of the block above it and every downstream node
// below it, according to the current block order.
bool StaticShapeGroupScheduler::IsKeepGraphDependency(Expr schedule_block,
                                                      Expr target_loop,
                                                      int insert_pos) const {
  // Assuming inserting the schedule_block into the target_loop,
  // obtain the transformed upstream and downstream blocks.
  std::unordered_set<std::string> blocks_above;
  std::unordered_set<std::string> blocks_below;
  bool is_below = false;
  bool find_target_loop = false;
  int pos_count = -1;
  // Idiomatic range-for over the ordered map (was an explicit
  // const_iterator loop).
  for (const auto& order_and_stmt : blocks_order_with_ctrl_stmt_) {
    const ir::Expr& stmt = order_and_stmt.second;
    // Skip the block being moved itself.
    if (stmt.get() == schedule_block.get()) {
      continue;
    }
    if (stmt.get() == target_loop.get()) {
      find_target_loop = true;
    }
    // Count statements from the target loop onward; once insert_pos is
    // reached, everything that follows is below the insertion point.
    if (find_target_loop) {
      ++pos_count;
    }
    if (pos_count == insert_pos) {
      is_below = true;
    }
    if (stmt.As<ir::ScheduleBlockRealize>()) {
      std::string block_id = stmt.As<ir::ScheduleBlockRealize>()
                                 ->schedule_block.As<ir::ScheduleBlock>()
                                 ->name;
      if (is_below) {
        blocks_below.insert(block_id);
      } else {
        blocks_above.insert(block_id);
      }
    }
  }
  // Obtain real upstream and downstream nodes
  std::string src_id = schedule_block.As<ir::ScheduleBlockRealize>()
                           ->schedule_block.As<ir::ScheduleBlock>()
                           ->name;
  ir::ScheduleBlockNode* node = schedule_block_graph_->RetrieveNode(src_id);
  std::unordered_set<std::string> upstream_ids = node->UpstreamNodes();
  std::unordered_set<std::string> downstream_ids = node->DownstreamNodes();
  // Check that the transformed upstream and downstream blocks
  // still meet the relationship between the
  // original upstream and downstream nodes.
  for (const std::string& id : upstream_ids) {
    if (blocks_above.count(id) == 0) {
      VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id
              << " cannot be insert into target loop at insert_pos: "
              << insert_pos << " because its upstream block: " << id
              << " will appear downstream.";
      VLOG(6) << "The target loop:\n" << target_loop;
      return false;
    }
  }
  for (const std::string& id : downstream_ids) {
    if (blocks_below.count(id) == 0) {
      VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id
              << " cannot be insert into target loop at insert_pos: "
              << insert_pos << " because its downstream block: " << id
              << " will appear upstream.";
      VLOG(6) << "The target loop:\n" << target_loop;
      return false;
    }
  }
  VLOG(6) << "[Meet Graph Level Dependency] ScheduleBlock: " << src_id
          << " can be insert into target loop at insert_pos: " << insert_pos;
  VLOG(6) << "The target loop:\n" << target_loop;
  return true;
}
// The move is feasible only if every registered feasibility condition
// accepts it.
bool StaticShapeGroupScheduler::MeetConditions(Expr schedule_block,
                                               Expr target_loop,
                                               int insert_pos) const {
  for (FeasibleCondition condition : feasible_conditions_) {
    bool feasible =
        (this->*condition)(schedule_block, target_loop, insert_pos);
    if (!feasible) {
      return false;
    }
  }
  return true;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
namespace cinn {
namespace ir {
// The priority of the ScheduleBlockNode,
// prioritizing whether it has been bound to the cuda axis,
// and secondly considering the amount of calculated data.
struct NodePriority {
  bool has_loop_binded;
  double score;
  // A node already bound to the cuda axis always outranks an unbound one;
  // when both are (un)bound, the calculated-data score breaks the tie.
  bool operator<(const NodePriority& other) const {
    if (has_loop_binded != other.has_loop_binded) {
      return !has_loop_binded;
    }
    return score < other.score;
  }
};
/**
* The class used for scheduling fusion groups with static shape.
* Its responsibility is to perform loop alignment,
* automatic inline, automatic loop fusion,
* and optimize the storage location of intermediate variables.
* Note: Currently only CUDA backend is supported.
*/
class StaticShapeGroupScheduler : public GroupScheduler {
 public:
  StaticShapeGroupScheduler(
      ir::IRSchedule* ir_sch,
      const std::unordered_set<std::string>& output_tensor_names,
      const common::Target& target)
      : GroupScheduler(ir_sch, output_tensor_names, target) {}
  // Entry point: run the scheduling passes listed in the class comment.
  void Schedule() override;
  // NOTE(review): not an override; presumably the schedule entry used for
  // MapExpr lowering — confirm with callers.
  void MapExprSchedule();
  // Return the scheduled IRs, each paired with its symbolic predicate.
  std::vector<std::pair<SymbolicPredicate, ir::Expr>> GetIRs() override;
 private:
  // Automatically align loops for each ScheduleBlock.
  void DoLoopAlignment();
  // Automatically inline some ScheduleBlock which meets the conditions.
  void DoComputeInline();
  // Make every effort to automatically merge the loops of the horizontal
  // relationship ScheduleBlockNode.
  void DoHorizontalLoopFusion();
  // Make every effort to automatically merge the loops of the vertical
  // relationship ScheduleBlockNode.
  void DoVerticalLoopFusion();
  // Automatically bind cuda axis on loops.
  void BindCudaAxis();
  // Automatically allocate storage locations for variables to optimize IO.
  void AllocateStorage();
  // Automatically optimize the reductive calculation
  void OptimizeReduction();
  // Evaluate the priority of ScheduleBlockNode.
  // The node where the performance bottleneck is located
  // has a higher priority, while the node with a lower priority
  // needs to compromise and align loops with the node with the highest
  // priority.
  NodePriority CalculateNodePriority(const ir::ScheduleBlockNode* node) const;
  // Find the highest priority ScheduleBlockNode,
  // other nodes need to align the loop with it.
  ir::ScheduleBlockNode* FindGlobalMasterNode() const;
  // Obtain the latest order of ScheduleBlock and the control structures
  // throughout the entire IR.
  void UpdateBlockOrder();
  // Get output tensor names of group.
  std::unordered_set<std::string> OutputTensorNames() const;
  /**
   * @brief Determine whether the graph level dependency is still maintained
   * after the schedule_block is placed in the insert position of target_loop.
   * @param schedule_block The src schedule_block to be replaced.
   * @param target_loop The target loop to be insert into the schedule_block.
   * @param insert_pos The insert position of new schedule_block in the
   * target_loop.
   */
  bool IsKeepGraphDependency(Expr schedule_block,
                             Expr target_loop,
                             int insert_pos) const;
  /**
   * @brief Determine whether all feasible conditions are met
   * after the schedule_block is placed in the insert position of target_loop.
   * @param schedule_block The src schedule_block to be replaced.
   * @param target_loop The target loop to be insert into the schedule_block.
   * @param insert_pos The insert position of new schedule_block in the
   * target_loop.
   */
  bool MeetConditions(Expr schedule_block,
                      Expr target_loop,
                      int insert_pos) const;
 private:
  /**
   * @brief Interface of feasibility condition.
   * @param schedule_block The src schedule_block to be replaced.
   * @param target_loop The target loop to be insert into the schedule_block.
   * @param insert_pos The insert position of new schedule_block in the
   * target_loop.
   */
  using FeasibleCondition = bool (StaticShapeGroupScheduler::*)(
      Expr schedule_block, Expr target_loop, int insert_pos) const;
  // All feasible conditions.
  std::vector<FeasibleCondition> feasible_conditions_;
  /**
   * The order of blocks and their control statements,
   * only For, IfThenElse and ScheduleBlock is considered.
   *
   * Example:
   * for0:
   *   for1:
   *     block0
   *     block1
   *   block2
   *   for2:
   *     block3
   * block4
   *
   * the result is:
   * [0]: for0
   * [0, 0]: for1
   * [0, 0, 0]: block0
   * [0, 0, 1]: block1
   * [0, 1]: block2
   * [0, 2]: for2
   * [0, 2, 0]: block3
   * [0, 2, 1]: block4
   */
  std::map<std::vector<int>, ir::Expr> blocks_order_with_ctrl_stmt_;
};
} // namespace ir
} // namespace cinn
......@@ -20,10 +20,10 @@
#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn {
......@@ -257,7 +257,7 @@ Expr For::Make(Var loop_var,
node->min = min;
node->extent = extent;
node->device_api = device_api;
node->body = body;
node->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
node->set_for_type(for_type);
node->set_vectorize_info(vector_info);
node->set_bind_info(bind_info);
......@@ -346,6 +346,10 @@ std::vector<const Expr *> ScheduleBlockRealize::expr_fields() const {
}
Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) {
if (true_case.defined() && (!true_case.As<Block>()))
true_case = ir::Block::Make({true_case});
if (false_case.defined() && (!false_case.As<Block>()))
false_case = ir::Block::Make({false_case});
auto node = make_shared<IfThenElse>(condition, true_case, false_case);
return Expr(node);
}
......@@ -513,7 +517,7 @@ Expr PolyFor::Make(Var iterator,
n->condition = condition;
n->inc = inc;
n->device_api = device_api;
n->body = body;
n->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
n->set_for_type(for_type);
n->set_vectorize_info(vectorize_info);
n->set_bind_info(bind_info);
......@@ -531,7 +535,7 @@ std::vector<const Expr *> PolyFor::expr_fields() const {
}
Expr PolyFor::ExtractExtent() const {
auto nodes = CollectIRNodes(condition, [&](const Expr *e) {
auto nodes = ir::ir_utils::CollectIRNodes(condition, [&](const Expr *e) {
return e->As<NE>() || //
e->As<EQ>() || //
e->As<Min>() || //
......
......@@ -1002,6 +1002,7 @@ struct _Module_ : public ExprNode<_Module_> {
std::vector<Expr> buffers;
std::vector<Expr> functions;
std::vector<Expr> submodules;
std::vector<Expr> predicates;
static ir::Module Make(const std::string& name, Target target);
......@@ -1011,7 +1012,7 @@ struct _Module_ : public ExprNode<_Module_> {
};
/**
* \brief PrimitiveNode holds the contept of Primitive in CINN.
* \brief PrimitiveNode holds the concept of Primitive in CINN.
* A Primitive is a basic Call to some Expr function, it is introduced to create
* several level of coarsed-grained IR nodes for better IR optimization and
* hardware adaption.
......
core_gather_headers()
gather_srcs(cinnapi_src SRCS ir_analyzer.cc)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/schedule_base.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace ir {
namespace analyzer {
namespace {
// Collects ScheduleBlockRealize nodes from an expression tree.
// With a non-empty block_name, traversal stops as soon as the first matching
// block has been found; with an empty name, every realized schedule block
// (i.e. one with non-empty iter_values) is collected.
struct FindBlocksVisitor {
  explicit FindBlocksVisitor(const std::string& block_name = "")
      : block_name_(block_name) {}
  // Walk `expr` and return the collected blocks.
  std::vector<Expr> operator()(const Expr* expr) {
    Visit(expr);
    return result;
  }
 private:
  void Visit(const Expr* expr) {
    if (!expr->defined()) return;
    // Early exit: when searching by name, one hit is enough.
    if (!block_name_.empty() && !result.empty()) return;
    if (expr->As<ir::For>()) {
      Visit(&(expr->As<ir::For>()->body));
    } else if (expr->As<ir::ScheduleBlockRealize>()) {
      // Only realizes with iter_values are candidate blocks; a realize
      // without iter_values (e.g. the root) is traversed through.
      if (!expr->As<ir::ScheduleBlockRealize>()->iter_values.empty()) {
        auto* schedule_block = expr->As<ir::ScheduleBlockRealize>()
                                   ->schedule_block.As<ir::ScheduleBlock>();
        if (block_name_.empty() || schedule_block->name == block_name_) {
          result.emplace_back(*expr);
        }
      } else {
        Visit(&(expr->As<ir::ScheduleBlockRealize>()->schedule_block));
      }
    } else if (expr->As<ir::ScheduleBlock>()) {
      Visit(&(expr->As<ir::ScheduleBlock>()->body));
    } else if (expr->As<ir::Block>()) {
      for (auto& n : expr->As<ir::Block>()->stmts) Visit(&n);
    } else if (expr->As<ir::IfThenElse>()) {
      Visit(&(expr->As<ir::IfThenElse>()->true_case));
      Visit(&(expr->As<ir::IfThenElse>()->false_case));
    }
  }
  std::string block_name_;   // empty means "collect all blocks"
  std::vector<Expr> result{};
};
// Collects the stack of For loops that enclose a given ScheduleBlockRealize.
// `father_loops` tracks the loops on the current traversal path; when the
// target block is reached it is snapshotted into `result` and traversal ends.
struct FindLoopsVisitor {
  explicit FindLoopsVisitor(const Expr& block) : block_(block) {}
  // Walk `expr` and return the loops enclosing block_, outermost first.
  std::vector<Expr> operator()(const Expr* expr) {
    CHECK(block_.As<ir::ScheduleBlockRealize>());
    visit_end = false;
    Visit(expr);
    return result;
  }
 private:
  void Visit(const Expr* expr) {
    if (visit_end || !expr->defined()) return;
    if (expr->As<ir::For>()) {
      // Push the loop while visiting its body, pop on the way back up.
      father_loops.emplace_back(*expr);
      Visit(&(expr->As<ir::For>()->body));
      father_loops.pop_back();
    } else if (expr->As<ir::ScheduleBlockRealize>()) {
      if (!expr->As<ir::ScheduleBlockRealize>()->iter_values.empty() &&
          (*expr == block_)) {
        // Found the target block: record the current loop stack and stop.
        result = father_loops;
        visit_end = true;
        return;
      } else {
        Visit(&(expr->As<ir::ScheduleBlockRealize>()->schedule_block));
      }
    } else if (expr->As<ir::ScheduleBlock>()) {
      Visit(&(expr->As<ir::ScheduleBlock>()->body));
    } else if (expr->As<ir::Block>()) {
      for (auto& n : expr->As<ir::Block>()->stmts) Visit(&n);
    } else if (expr->As<ir::IfThenElse>()) {
      Visit(&(expr->As<ir::IfThenElse>()->true_case));
      Visit(&(expr->As<ir::IfThenElse>()->false_case));
    }
  }
  std::vector<Expr> father_loops{};  // loops on the current traversal path
  std::vector<Expr> result{};
  bool visit_end{false};
  const Expr& block_;
};
// Locates the direct parent expression (Block, For, or ScheduleBlock) whose
// immediate child is the ScheduleBlockRealize named `block_name`. On success
// `target_` points at the parent; otherwise it stays nullptr.
struct FindBlockParent : public ir::IRMutator<> {
 public:
  explicit FindBlockParent(const std::string& block_name)
      : block_name_(block_name) {}
  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  // True iff `e` is a ScheduleBlockRealize whose inner block name matches
  // block_name_. Shared by the three Visit overloads (was duplicated).
  bool IsTargetRealize(const ir::Expr& e) const {
    return e.As<ir::ScheduleBlockRealize>() &&
           e.As<ir::ScheduleBlockRealize>()
                   ->schedule_block.As<ir::ScheduleBlock>()
                   ->name == block_name_;
  }
  void Visit(const ir::Block* expr, Expr* op) override {
    if (target_) return;
    for (auto& stmt : expr->stmts) {
      if (IsTargetRealize(stmt)) {
        target_ = op;
        return;
      }
    }
    IRMutator::Visit(expr, op);
  }
  void Visit(const ir::For* expr, Expr* op) override {
    if (target_) return;
    if (IsTargetRealize(expr->body)) {
      target_ = op;
      return;
    }
    IRMutator::Visit(expr, op);
  }
  void Visit(const ir::ScheduleBlock* expr, Expr* op) override {
    if (target_) return;
    if (IsTargetRealize(expr->body)) {
      target_ = op;
      return;
    }
    IRMutator::Visit(expr, op);
  }
  std::string block_name_;

 public:
  // Parent of the named block; nullptr if not found.
  ir::Expr* target_{nullptr};
};
} // namespace
bool HasBlock(const std::vector<Expr>& exprs, const std::string& block_name) {
for (auto& it_expr : exprs) {
FindBlocksVisitor visitor(block_name);
auto find_blocks = visitor(&it_expr);
if (!find_blocks.empty()) {
CHECK_EQ(find_blocks.size(), 1U)
<< "There should not be more than 1 block with identical name!";
return true;
}
}
return false;
}
std::vector<Expr> GetLoops(const std::vector<Expr>& exprs,
const std::string& block_name) {
Expr block = GetBlock(exprs, block_name);
std::vector<Expr> result = GetLoops(exprs, block);
return result;
}
std::vector<Expr> GetLoops(const std::vector<Expr>& exprs, const Expr& block) {
std::vector<Expr> result;
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
for (auto& it_expr : exprs) {
FindLoopsVisitor visitor(block);
auto find_loops = visitor(&it_expr);
if (!find_loops.empty()) {
if (!result.empty())
LOG(FATAL) << "Find block with name: \n"
<< block_name << " appeared in more than one AST!";
result = find_loops;
}
}
if (result.empty()) {
result.push_back(AddUnitLoop(exprs, block));
}
return result;
}
std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs) {
std::vector<Expr> result;
for (auto& it_expr : exprs) {
FindBlocksVisitor visitor;
auto find_blocks = visitor(&it_expr);
result.insert(result.end(), find_blocks.begin(), find_blocks.end());
}
for (auto& it_expr : exprs) {
VLOG(3) << "it_expr is : " << it_expr;
}
CHECK(!result.empty()) << "Didn't find blocks in expr.";
return result;
}
std::vector<Expr> GetChildBlocks(const Expr& expr) {
CHECK(expr.As<ir::ScheduleBlockRealize>() || expr.As<ir::For>());
FindBlocksVisitor visitor;
std::vector<Expr> result = visitor(&expr);
return result;
}
// Find the unique ScheduleBlock named `block_name` among all ASTs; fatal if
// it is missing or duplicated.
Expr GetBlock(const std::vector<Expr>& exprs, const std::string& block_name) {
  for (const Expr& expr : exprs) {
    FindBlocksVisitor visitor(block_name);
    std::vector<Expr> blocks = visitor(&expr);
    if (blocks.empty()) {
      continue;
    }
    CHECK_EQ(blocks.size(), 1U)
        << "There should not be more than 1 block with identical name!";
    return blocks[0];
  }
  LOG(FATAL) << "Didn't find a block with name " << block_name
             << " in this ModuleExpr!";
}
// Return the root ScheduleBlockRealize of the AST that contains `expr`;
// fatal if no AST contains it.
Expr GetRootBlock(const std::vector<Expr>& exprs, const Expr& expr) {
  for (const Expr& candidate : exprs) {
    auto matched = ir::ir_utils::CollectIRNodesWithoutTensor(
        candidate,
        [&](const Expr* x) {
          return x->node_type() == expr.node_type() && *x == expr;
        },
        true);
    if (matched.empty()) {
      continue;
    }
    // The AST is expected to be a Block holding exactly one realize: the root.
    CHECK(candidate.As<ir::Block>());
    CHECK_EQ(candidate.As<ir::Block>()->stmts.size(), 1U);
    CHECK(candidate.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>());
    return candidate.As<ir::Block>()->stmts[0];
  }
  LOG(FATAL) << "Didn't find expr \n"
             << expr << "in StScheduleImpl:\n"
             << exprs[0];
}
// Read the device API off the first For loop found in the first AST.
DeviceAPI GetDeviceAPI(const std::vector<Expr>& exprs) {
  auto for_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
      exprs.front(), [&](const Expr* x) { return x->As<ir::For>(); }, true);
  CHECK(!for_nodes.empty());
  return for_nodes.begin()->As<ir::For>()->device_api;
}
// Ensures `block` is surrounded by at least one loop by inserting a
// unit-extent serial loop `for (ix, 0, 1)` around it in place, and returns
// the newly created loop. The insertion point depends on the kind of the
// block's parent node (Block / For / ScheduleBlock); aborts if no parent
// can be located.
Expr AddUnitLoop(const std::vector<Expr>& exprs, const Expr& block) {
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(block.As<ir::ScheduleBlockRealize>()
            ->schedule_block.As<ir::ScheduleBlock>());
  std::string block_name = block.As<ir::ScheduleBlockRealize>()
                               ->schedule_block.As<ir::ScheduleBlock>()
                               ->name;
  // Locate the parent node of the block with this name in any of the ASTs.
  FindBlockParent visitor(block_name);
  for (auto expr : exprs) {
    visitor(&expr);
    if (visitor.target_) {
      break;
    }
  }
  CHECK(visitor.target_) << ", block name : " << block_name << "\n" << exprs;
  if (visitor.target_->As<ir::Block>()) {
    // Parent is a Block: replace the matching statement with the new loop.
    for (auto& stmt : visitor.target_->As<ir::Block>()->stmts) {
      if (stmt.As<ir::ScheduleBlockRealize>()) {
        if (stmt.As<ir::ScheduleBlockRealize>()
                ->schedule_block.As<ir::ScheduleBlock>()
                ->name == block_name) {
          // Shadows the `block` parameter intentionally; wraps the found
          // schedule block in a fresh Block before making the loop.
          auto block = ir::Block::Make({GetBlock(exprs, block_name)});
          auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
                                    ir::Expr(0),
                                    ir::Expr(1),
                                    ir::ForType::Serial,
                                    ir::DeviceAPI::UNK,
                                    block);
          stmt = loop;
          return loop;
        }
      }
    }
  } else if (visitor.target_->As<ir::For>()) {
    // Parent is a For: wrap its body in the unit loop.
    auto block = ir::Block::Make({visitor.target_->As<ir::For>()->body});
    auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
                              ir::Expr(0),
                              ir::Expr(1),
                              ir::ForType::Serial,
                              ir::DeviceAPI::UNK,
                              block);
    visitor.target_->As<ir::For>()->body = loop;
    return loop;
  } else if (visitor.target_->As<ir::ScheduleBlock>()) {
    // Parent is a ScheduleBlock: wrap its body in the unit loop.
    auto block =
        ir::Block::Make({visitor.target_->As<ir::ScheduleBlock>()->body});
    auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
                              ir::Expr(0),
                              ir::Expr(1),
                              ir::ForType::Serial,
                              ir::DeviceAPI::UNK,
                              block);
    visitor.target_->As<ir::ScheduleBlock>()->body = loop;
    return loop;
  } else {
    LOG(FATAL) << "Can't find block's parent!";
  }
  LOG(FATAL) << "Shouldn't reach code here in AddUnitLoop";
  return Expr{nullptr};
}
} // namespace analyzer
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace cinn {
namespace ir {
namespace analyzer {
//! Returns true iff a schedule block named `block_name` exists in `exprs`.
bool HasBlock(const std::vector<Expr>& exprs, const std::string& block_name);
//! Returns the loops surrounding the block identified by name / by node.
std::vector<Expr> GetLoops(const std::vector<Expr>& exprs,
                           const std::string& block_name);
std::vector<Expr> GetLoops(const std::vector<Expr>& exprs, const Expr& block);
//! Collects every schedule block contained in `exprs`.
std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs);
//! Collects the schedule blocks nested inside `expr` (a block or a loop).
std::vector<Expr> GetChildBlocks(const Expr& expr);
//! Finds the unique schedule block with the given name; fatal if absent.
Expr GetBlock(const std::vector<Expr>& exprs, const std::string& block_name);
//! Returns the root ScheduleBlockRealize of the AST containing `expr`.
Expr GetRootBlock(const std::vector<Expr>& exprs, const Expr& expr);
//! Returns the device API of the first For loop in the first expression.
DeviceAPI GetDeviceAPI(const std::vector<Expr>& exprs);
//! Wraps `block` (or its parent's body) in a unit-extent serial loop.
Expr AddUnitLoop(const std::vector<Expr>& exprs, const Expr& block);
} // namespace analyzer
} // namespace ir
} // namespace cinn
......@@ -18,10 +18,10 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace ir {
......
......@@ -51,6 +51,7 @@ class _BufferRange_;
class BufferRange;
class ScheduleBlock;
class ScheduleBlockRealize;
class Dim;
// clang-format off
#define NODETY_PRIMITIVE_TYPE_FOR_EACH(macro__) \
......@@ -113,6 +114,7 @@ class ScheduleBlockRealize;
macro__(_BufferRange_) \
macro__(ScheduleBlock) \
macro__(ScheduleBlockRealize) \
macro__(_Dim_) \
#define NODETY_FORALL(__m) \
......
......@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace ir {} // namespace ir
......
......@@ -19,7 +19,7 @@
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace cinn {
namespace ir {
......@@ -348,5 +348,12 @@ void IRMutator<T>::Visit(const ScheduleBlockRealize *expr, T op) {
&node->schedule_block);
}
// Mutator hook for symbolic-dimension (_Dim_) nodes. Currently it only
// validates the node kind; the traversal into the symbolic dim expression
// is deliberately left disabled (see the commented-out line).
template <typename T>
void IRMutator<T>::Visit(const _Dim_ *expr, T op) {
  auto *node = op->template As<_Dim_>();
  CHECK(node);
  // IRVisitorRequireReImpl<void, T>::Visit(&node->sym_dim, &node->sym_dim);
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <iomanip>
#include <limits>
#include <vector>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace ir {
using common::bfloat16;
using common::float16;
// Renders a single expression. The Visit overloads accumulate text into
// `str_`; Print flushes the accumulated text to `os_` and clears it.
void IrPrinter::Print(const Expr &e) {
  IRVisitorRequireReImpl::Visit(&e);
  os_ << str_;
  str_ = "";
}
// Renders a list of expressions separated by `splitter`, then flushes the
// accumulated text to the output stream.
void IrPrinter::Print(const std::vector<Expr> &exprs,
                      const std::string &splitter) {
  // Emit all but the last element followed by the splitter; the `!empty()`
  // guard protects `size() - 1` style arithmetic on an empty vector.
  for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
    Visit(exprs[i]);
    str_ += splitter;
  }
  if (!exprs.empty()) Visit(exprs.back());
  os_ << str_;
  str_ = "";
}
// Prints a signed integer literal with a width-appropriate suffix or cast
// so the emitted text denotes the same type in generated code.
void IrPrinter::Visit(const IntImm *x) {
  if (x->type().is_int(64)) {
    str_ += std::to_string(x->value);
    str_ += "ll";
  } else if (x->type().is_int(32)) {
    str_ += std::to_string(x->value);
  } else if (x->type().is_int(16)) {
    str_ += "(int16_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_int(8)) {
    str_ += "(int8_t)";
    str_ += std::to_string(x->value);
  } else {
    LOG(FATAL) << "Not support int type: " << x->type();
  }
}
// Prints an unsigned integer literal; 1-bit values render as true/false.
void IrPrinter::Visit(const UIntImm *x) {
  if (x->type().is_uint(64)) {
    str_ += std::to_string(x->value);
    str_ += "ull";
  } else if (x->type().is_uint(32)) {
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(16)) {
    str_ += "(uint16_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(8)) {
    str_ += "(uint8_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(1)) {
    if (x->value) {
      str_ += "true";
    } else {
      str_ += "false";
    }
  } else {
    LOG(FATAL) << "Not support uint type: " << x->type();
  }
}
// Prints a floating-point literal. inf/nan for 16-bit types are emitted as
// raw-bit-pattern constructor calls; max_digits10 precision guarantees the
// decimal text round-trips to the same binary value.
void IrPrinter::Visit(const FloatImm *x) {
  std::ostringstream ss;
  if (x->type().is_float16()) {
    if (std::isinf(x->value)) {
      ss << "cinn::common::raw_uint16_to_float16(0x7c00)";
    } else if (std::isnan(x->value)) {
      ss << "cinn::common::raw_uint16_to_float16(0x7e00)";
    } else {
      ss << "(float16)";
      ss << std::setprecision(std::numeric_limits<float16>::max_digits10);
      ss << static_cast<float16>(x->value) << "f";
    }
  } else if (x->type().is_bfloat16()) {
    if (std::isinf(x->value)) {
      ss << "cinn::common::raw_uint16_to_bfloat16(0x7F80)";
    } else if (std::isnan(x->value)) {
      ss << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)";
    } else {
      ss << "(bfloat16)";
      ss << std::setprecision(std::numeric_limits<bfloat16>::max_digits10);
      ss << static_cast<bfloat16>(x->value) << "f";
    }
  } else if (x->type().is_float(32)) {
    ss << std::setprecision(std::numeric_limits<float>::max_digits10);
    ss << std::showpoint;
    ss << x->value;
    // The "f" suffix is skipped for inf/nan, which have no literal form.
    if (std::isfinite(x->value)) {
      ss << "f";
    }
  } else if (x->type().is_float(64)) {
    ss << std::setprecision(std::numeric_limits<double>::max_digits10);
    ss << std::showpoint;
    ss << x->value;
  } else {
    LOG(FATAL) << "Not support float type: " << x->type();
  }
  str_ += ss.str();
}
// Prints a string literal wrapped in double quotes (no escaping applied).
void IrPrinter::Visit(const StringImm *x) {
  str_ += "\"";
  str_ += x->value;
  str_ += "\"";
}
// Arithmetic, comparison and logical operators print in parenthesized
// infix form via PrintBinaryOp.
void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); }
void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); }
void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); }
void IrPrinter::Visit(const Div *x) { PrintBinaryOp("/", x); }
void IrPrinter::Visit(const Mod *x) { PrintBinaryOp("%", x); }
void IrPrinter::Visit(const EQ *x) { PrintBinaryOp("==", x); }
void IrPrinter::Visit(const NE *x) { PrintBinaryOp("!=", x); }
void IrPrinter::Visit(const LT *x) { PrintBinaryOp("<", x); }
void IrPrinter::Visit(const LE *x) { PrintBinaryOp("<=", x); }
void IrPrinter::Visit(const GT *x) { PrintBinaryOp(">", x); }
void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); }
void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); }
void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); }
// Logical negation: prefix "!" before the operand.
void IrPrinter::Visit(const Not *x) {
  str_ += "!";
  Visit(x->v());
}
// Min/Max print as calls to the cinn_min / cinn_max runtime helpers.
void IrPrinter::Visit(const Min *x) {
  str_ += "cinn_min(";
  Visit(x->a());
  str_ += ", ";
  Visit(x->b());
  str_ += ")";
}
void IrPrinter::Visit(const Max *x) {
  str_ += "cinn_max(";
  Visit(x->a());
  str_ += ", ";
  Visit(x->b());
  str_ += ")";
}
// Unary minus: prefix "-" with the operand parenthesized.
void IrPrinter::Visit(const Minus *x) {
  str_ += "-(";
  Visit(x->v());
  str_ += ")";
}
// Prints a For loop as `<kind> for (var, min, extent)` followed by the
// indented body. The prefix reflects the loop's schedule annotation.
void IrPrinter::Visit(const For *x) {
  if (x->is_parallel()) {
    str_ += "parallel for (";
  } else if (x->is_unrolled()) {
    str_ += "unroll for (";
  } else if (x->is_vectorized()) {
    int factor = x->vectorize_info().factor;
    str_ += "vectorize[";
    str_ += std::to_string(factor);
    str_ += "] for (";
  } else if (x->is_binded()) {
    auto &bind_info = x->bind_info();
    if (bind_info.valid()) {
      // offset 0/1/2 maps to axis letter x/y/z.
      char axis_name = 'x' + bind_info.offset;
      auto for_type = bind_info.for_type;
      std::string prefix =
          for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx.";
      str_ += "thread_bind[";
      str_ += prefix;
      str_ += axis_name;
      str_ += "] for (";
    } else {
      str_ += "thread_bind[invalid info] for (";
    }
  } else if (x->is_serial()) {
    str_ += "serial for (";
  } else if (x->is_default()) {
    str_ += "default for (";
  } else {
    str_ += "for (";
  }
  Visit(x->loop_var);
  str_ += ", ";
  Visit(x->min);
  str_ += ", ";
  Visit(x->extent);
  str_ += ")\n";
  DoIndent();
  Visit(x->body);
}
// Prints a PolyFor as `poly_for (iterator, init, condition, inc)` plus body.
void IrPrinter::Visit(const PolyFor *x) {
  if (x->is_parallel()) {
    str_ += "parallel poly_for (";
  } else {
    str_ += "poly_for (";
  }
  Visit(x->iterator);
  str_ += ", ";
  Visit(x->init);
  str_ += ", ";
  Visit(x->condition);
  str_ += ", ";
  Visit(x->inc);
  str_ += ")\n";
  DoIndent();
  Visit(x->body);
}
// Prints an if/else; the else branch is emitted only when defined.
void IrPrinter::Visit(const IfThenElse *x) {
  str_ += "if (";
  Visit(x->condition);
  str_ += ") ";
  Visit(x->true_case);
  if (x->false_case.defined()) {
    str_ += " else ";
    Visit(x->false_case);
  }
}
// Prints a statement block wrapped in braces with one extra indent level,
// one statement per line.
void IrPrinter::Visit(const Block *x) {
  str_ += "{\n";
  IncIndent();
  for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) {
    DoIndent();
    Visit(x->stmts[i]);
    str_ += "\n";
  }
  if (!x->stmts.empty()) {
    DoIndent();
    Visit(x->stmts.back());
  }
  DecIndent();
  str_ += "\n";
  DoIndent();
  str_ += "}";
}
// Prints a call as `name(read_args..., write_args...)`, comma separated.
void IrPrinter::Visit(const Call *x) {
  str_ += x->name;
  str_ += "(";
  if (!x->read_args.empty()) {
    for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) {
      Visit(x->read_args[i]);
      str_ += ", ";
    }
    Visit(x->read_args.back());
  }
  if (!x->write_args.empty()) {
    // Separate read and write argument groups with a comma when both exist.
    if (!x->read_args.empty()) str_ += ", ";
    for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) {
      Visit(x->write_args[i]);
      str_ += ", ";
    }
    Visit(x->write_args.back());
  }
  str_ += ")";
}
// Prints a cast as `target_type(value)`.
void IrPrinter::Visit(const Cast *x) {
  str_ += x->type().to_string();
  str_ += "(";
  Visit(x->v());
  str_ += ")";
}
// Modules are intentionally printed as nothing at the expression level.
void IrPrinter::Visit(const _Module_ *x) {}
// Variables print as their bare name.
void IrPrinter::Visit(const _Var_ *x) { str_ += x->name; }
// Prints `alloc(buffer_name, extents)`; requires a _Buffer_ destination.
void IrPrinter::Visit(const Alloc *x) {
  auto *buffer = x->destination.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += "alloc(";
  str_ += buffer->name;
  str_ += ", ";
  Visit(x->extents);
  str_ += ")";
}
// Prints `select(condition, true_value, false_value)`.
void IrPrinter::Visit(const Select *x) {
  str_ += "select(";
  Visit(x->condition);
  str_ += ", ";
  Visit(x->true_value);
  str_ += ", ";
  Visit(x->false_value);
  str_ += ")";
}
// Prints a load as `tensor_or_scalar[indices...]`.
void IrPrinter::Visit(const Load *x) {
  if (x->is_addr_tensor()) {
    auto *tensor = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor);
    str_ += tensor->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "]";
}
// Prints a store as `tensor_or_scalar[indices...] = value`.
void IrPrinter::Visit(const Store *x) {
  if (x->is_addr_tensor()) {
    auto *tensor_node = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor_node);
    str_ += tensor_node->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "] = ";
  Visit(x->value);
}
// Prints `free(buffer_name)`; requires a _Buffer_ destination.
void IrPrinter::Visit(const Free *x) {
  auto *buffer = x->destination.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += "free(";
  str_ += buffer->name;
  str_ += ")";
}
// Indentation helpers: emit `indent_` spaces / adjust by one unit.
void IrPrinter::DoIndent() { str_ += std::string(indent_, ' '); }
void IrPrinter::IncIndent() { indent_ += indent_unit; }
void IrPrinter::DecIndent() { indent_ -= indent_unit; }
// Prints a buffer as `_Buffer_<type: d0,d1,...>(name)`.
void IrPrinter::Visit(const _Buffer_ *x) {
  std::vector<std::string> dim_names;
  // Stringify each shape expression via the stream operator.
  std::transform(x->shape.begin(),
                 x->shape.end(),
                 std::back_inserter(dim_names),
                 [&](const Expr &x) { return utils::GetStreamCnt(x); });
  str_ += "_Buffer_<";
  str_ += x->type().to_string();
  str_ += ": ";
  str_ += utils::Join(dim_names, ",");
  str_ += ">(";
  str_ += x->name;
  str_ += ")";
}
// Prints a tensor as `Tensor(name, [shape...])`.
void IrPrinter::Visit(const _Tensor_ *x) {
  str_ += "Tensor(";
  str_ += x->name;
  str_ += ", ";
  str_ += "[";
  if (!x->shape.empty()) {
    for (std::size_t i = 0; i + 1 < x->shape.size(); i++) {
      Visit(x->shape[i]);
      str_ += ",";
    }
    Visit(x->shape.back());
  }
  str_ += "])";
}
// Prints a lowered function signature `function name (args...)` plus body.
void IrPrinter::Visit(const _LoweredFunc_ *f) {
  str_ += "function ";
  str_ += f->name;
  str_ += " ";
  std::vector<std::string> arg_names;
  for (auto &arg : f->args) {
    arg_names.push_back(arg.name());
  }
  str_ += "(";
  str_ += utils::Join(arg_names, ", ");
  str_ += ")\n";
  Visit(f->body);
}
// Prints a let binding as `type symbol` or `type symbol = body`.
void IrPrinter::Visit(const Let *f) {
  CHECK(f->type().valid());
  str_ += f->type().to_string();
  str_ += " ";
  Visit(f->symbol);
  if (f->body.defined()) {
    str_ += " = ";
    Visit(f->body);
  }
}
// Prints a reduction as `Reduce(kind, body,init)`. Note the kind labels
// are not uniformly cased ("sum"/"sub" vs "Div"/"Mul"/...), matching the
// historical output format.
void IrPrinter::Visit(const Reduce *f) {
  str_ += "Reduce(";
  switch (f->reduce_type) {
    case Reduce::ReduceType::kSum:
      str_ += "sum";
      break;
    case Reduce::ReduceType::kSub:
      str_ += "sub";
      break;
    case Reduce::ReduceType::kDiv:
      str_ += "Div";
      break;
    case Reduce::ReduceType::kMul:
      str_ += "Mul";
      break;
    case Reduce::ReduceType::kMax:
      str_ += "Max";
      break;
    case Reduce::ReduceType::kMin:
      str_ += "Min";
      break;
    case Reduce::ReduceType::kAll:
      str_ += "&&";
      break;
    case Reduce::ReduceType::kAny:
      str_ += "||";
      break;
  }
  str_ += ", ";
  Visit(f->body);
  str_ += ",";
  Visit(f->init);
  str_ += ")";
}
// Prints `Ramp(base,stride,lanes)`.
void IrPrinter::Visit(const Ramp *x) {
  str_ += "Ramp(";
  Visit(x->base);
  str_ += ",";
  Visit(x->stride);
  str_ += ",";
  str_ += std::to_string(x->lanes);
  str_ += ")";
}
// Prints `Broadcast(value,lanes)`.
void IrPrinter::Visit(const Broadcast *x) {
  str_ += "Broadcast(";
  Visit(x->value);
  str_ += ",";
  str_ += std::to_string(x->lanes);
  str_ += ")";
}
// Prints a fraction as `(a / b)`.
void IrPrinter::Visit(const FracOp *x) {
  str_ += "(";
  Visit(x->a());
  str_ += " / ";
  Visit(x->b());
  str_ += ")";
}
// Prints an n-ary product as `(op0 * op1 * ...)`.
void IrPrinter::Visit(const Product *x) {
  str_ += "(";
  for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
    Visit(x->operand(i));
    str_ += " * ";
  }
  if (!x->operands().empty()) Visit(x->operands().back());
  str_ += ")";
}
// Prints an n-ary sum as `(op0 + op1 + ...)`.
void IrPrinter::Visit(const Sum *x) {
  str_ += "(";
  for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
    Visit(x->operand(i));
    str_ += " + ";
  }
  if (!x->operands().empty()) Visit(x->operands().back());
  str_ += ")";
}
// Prints `name(args0,args1,...)` where each args group is itself
// comma-joined.
void IrPrinter::Visit(const PrimitiveNode *x) {
  str_ += x->name;
  str_ += "(";
  std::vector<std::string> args_repr;
  for (auto &args : x->arguments) {
    std::vector<std::string> arg_repr;
    for (auto &arg : args) {
      arg_repr.push_back(utils::GetStreamCnt(arg));
    }
    args_repr.push_back(utils::Join(arg_repr, ","));
  }
  str_ += utils::Join(args_repr, ",");
  str_ += ")";
}
// Prints a buffer range as `buffer[v0(lo:hi), v1(lo:hi), ...]`, writing
// "undefined" for any bound that is not set.
void IrPrinter::Visit(const _BufferRange_ *x) {
  auto *buffer = x->buffer.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += buffer->name;
  str_ += "[";
  for (std::size_t i = 0; i < x->ranges.size(); i++) {
    if (i) str_ += ", ";
    auto &range = x->ranges[i];
    str_ += range->name;
    str_ += "(";
    if (range->lower_bound.defined()) {
      Visit(range->lower_bound);
      str_ += ":";
    } else {
      str_ += "undefined:";
    }
    if (range->upper_bound.defined()) {
      Visit(range->upper_bound);
    } else {
      str_ += "undefined";
    }
    str_ += ")";
  }
  str_ += "]";
}
// A bare ScheduleBlock prints nothing; it is rendered through its
// enclosing ScheduleBlockRealize below.
void IrPrinter::Visit(const ScheduleBlock *x) {}
// Prints `ScheduleBlock(name) { axis bindings, read/write buffers, attrs,
// body }` with one indent level for the contents.
void IrPrinter::Visit(const ScheduleBlockRealize *x) {
  auto *schedule_block = x->schedule_block.As<ScheduleBlock>();
  str_ += "ScheduleBlock(";
  str_ += schedule_block->name;
  str_ += ")\n";
  DoIndent();
  str_ += "{\n";
  // print block vars and bindings
  auto iter_vars = schedule_block->iter_vars;
  auto iter_values = x->iter_values;
  // Every block var must have exactly one binding value.
  CHECK_EQ(iter_vars.size(), iter_values.size());
  IncIndent();
  if (!iter_vars.empty()) DoIndent();
  for (std::size_t i = 0; i < iter_vars.size(); i++) {
    if (i) str_ += ", ";
    str_ += iter_vars[i]->name;
  }
  if (!iter_vars.empty()) str_ += " = axis.bind(";
  for (std::size_t i = 0; i < iter_values.size(); i++) {
    if (i) str_ += ", ";
    Visit(iter_values[i]);
  }
  if (!iter_vars.empty()) str_ += ")\n";
  // print block body
  if (!schedule_block->read_buffers.empty()) {
    DoIndent();
    str_ += "read_buffers(";
    auto &read_buffers = schedule_block->read_buffers;
    for (std::size_t i = 0; i < read_buffers.size(); i++) {
      if (i) str_ += ", ";
      Visit(read_buffers[i]);
    }
    str_ += ")\n";
  }
  if (!schedule_block->write_buffers.empty()) {
    DoIndent();
    str_ += "write_buffers(";
    auto &write_buffers = schedule_block->write_buffers;
    for (std::size_t i = 0; i < write_buffers.size(); i++) {
      if (i) str_ += ", ";
      Visit(write_buffers[i]);
    }
    str_ += ")\n";
  }
  if (!schedule_block->attrs.empty()) {
    DoIndent();
    str_ += "attrs(";
    bool comma = false;
    for (auto &&kv : schedule_block->attrs) {
      if (comma) str_ += ", ";
      str_ += kv.first;
      str_ += ":";
      // Attribute values are variants; stream whichever alternative holds.
      absl::visit(
          [this](auto &&arg) {
            std::ostringstream ss;
            ss << arg;
            this->str_ += ss.str();
          },
          kv.second);
      comma = true;
    }
    str_ += ")\n";
  }
  DoIndent();
  Visit(schedule_block->body);
  str_ += "\n";
  DecIndent();
  DoIndent();
  str_ += "}";
}
// Prints a symbolic dimension as `Dim(name: ..., sym_name: ..., dim_size:
// ...)` using the node's accessors.
void IrPrinter::Visit(const _Dim_ *x) {
  str_ += "Dim(name: ";
  str_ += x->name;
  str_ += ", sym_name: ";
  str_ += x->GetSymbolName();
  str_ += ", dim_size: ";
  str_ += std::to_string(x->GetRealDimSize());
  str_ += ")";
}
// Dispatches an intrinsic op to its concrete Visit overload based on the
// runtime kind tag; the macro expands one case per intrinsic.
void IrPrinter::Visit(const IntrinsicOp *x) {
  switch (x->getKind()) {
#define __(op__)                                \
  case IntrinsicKind::k##op__:                  \
    Visit(llvm::dyn_cast<intrinsics::op__>(x)); \
    break;
    INTRINSIC_KIND_FOR_EACH(__)
#undef __
  }
}
// Prints a call to the buffer_get_data_handle intrinsic:
// `<intrinsic_name>(<buffer>)`.
void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_handle;
  // Fix: the opening parenthesis was missing, producing unbalanced output
  // like `cinn_buffer_get_data_handlebuf)` (cf. the PodValueToX printer).
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Prints a call to the buffer_get_data_const_handle intrinsic:
// `<intrinsic_name>(<buffer>)`.
void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_const_handle;
  // Fix: the opening parenthesis was missing, producing unbalanced output
  // (the closing ")" below had no matching "(").
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Prints `pod_value_to_<type>(ptr)` for POD value extraction.
void IrPrinter::Visit(const intrinsics::PodValueToX *x) {
  str_ += "pod_value_to_";
  str_ += x->GetOutputType(0).to_string();
  str_ += "(";
  Visit(x->pod_value_ptr);
  str_ += ")";
}
// Prints the buffer-create intrinsic as a zero-argument call.
void IrPrinter::Visit(const intrinsics::BufferCreate *x) {
  str_ += runtime::intrinsic::buffer_create;
  str_ += "()";
}
// Prints `get_addr(data)`.
void IrPrinter::Visit(const intrinsics::GetAddr *x) {
  str_ += "get_addr(";
  Visit(x->data);
  str_ += ")";
}
// Prints the args-construct intrinsic with its argument list.
void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) {
  str_ += runtime::intrinsic::args_construct_repr;
  str_ += "(";
  Visit(std::vector<Expr>(x->args.begin(), x->args.end()));
  str_ += ")";
}
// Prints a builtin intrinsic as `<repr>_<name>(args...)`.
void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) {
  str_ += runtime::intrinsic::builtin_intrin_repr;
  str_ += "_";
  str_ += x->name;
  str_ += "(";
  if (!x->args.empty()) {
    for (std::size_t i = 0; i + 1 < x->args.size(); i++) {
      Visit(x->args[i]);
      str_ += ", ";
    }
    Visit(x->args.back());
  }
  str_ += ")";
}
// Stream operators: each builds the text with a temporary IrPrinter over a
// stringstream, then forwards the result to the target stream.
std::ostream &operator<<(std::ostream &os, Expr a) {
  std::stringstream ss;
  IrPrinter printer(ss);
  printer.Print(a);
  os << ss.str();
  return os;
}
std::ostream &operator<<(std::ostream &os, const std::vector<Expr> &a) {
  std::stringstream ss;
  IrPrinter printer(ss);
  printer.Print(a);
  os << ss.str();
  return os;
}
// Prints a module as `Module <name> { <functions> }`, relying on the
// LoweredFunc stream operator for each function.
std::ostream &operator<<(std::ostream &os, const ir::Module &m) {
  os << "Module " << m->name << " {\n\n";
  for (auto &fn : m->functions) {
    os << fn << '\n';
  }
  os << "\n\n}";
  return os;
}
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace cinn {
namespace lang {
class LoweredFunc;
} // namespace lang
namespace ir {
class Module;
//! Visitor that renders CINN IR nodes as human-readable text. Visit
//! overloads append to the internal buffer `str_`; Print flushes the
//! buffer to the wrapped output stream.
struct IrPrinter : public IRVisitorRequireReImpl<void> {
  explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}

  //! Emit an expression on the output stream.
  void Print(const Expr &e);
  //! Emit a expression list with , splitted.
  void Print(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ");
  //! Emit a binary operator
  template <typename IRN>
  void PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x);

  //! Prefix the current line with `indent_` spaces.
  void DoIndent();
  //! Increase the indent size.
  void IncIndent();
  //! Decrease the indent size.
  void DecIndent();

  //! Access the wrapped output stream (buffered text is flushed by Print).
  std::ostream &os() { return os_; }

  void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
  //! Buffer-only variant of Print: appends to `str_` without flushing.
  void Visit(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ") {
    for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
      Visit(exprs[i]);
      str_ += splitter;
    }
    if (!exprs.empty()) Visit(exprs.back());
  }

// One Visit override per IR node kind, plus one per intrinsic op.
#define __(op__) void Visit(const op__ *x) override;
  NODETY_FORALL(__)
#undef __

#define __(op__) virtual void Visit(const intrinsics::op__ *x);
  INTRINSIC_KIND_FOR_EACH(__)
#undef __

 protected:
  //! Accumulated output text, flushed to `os_` by Print.
  std::string str_;

 private:
  std::ostream &os_;
  //! Current indentation width in spaces.
  uint16_t indent_{};
  //! Spaces added per indentation level.
  const int indent_unit{2};
};
std::ostream &operator<<(std::ostream &os, Expr a);
std::ostream &operator<<(std::ostream &os, const std::vector<Expr> &a);
std::ostream &operator<<(std::ostream &os, const Module &m);
// Renders a binary node as `(lhs <op> rhs)`, parenthesized so operator
// precedence in the printed text is unambiguous.
template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op,
                              const BinaryOpNode<IRN> *x) {
  str_ += "(";
  Visit(x->a());
  str_ += " ";
  str_ += op;
  str_ += " ";
  Visit(x->b());
  str_ += ")";
}
} // namespace ir
} // namespace cinn
......@@ -14,8 +14,8 @@
#include <unordered_set>
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -23,8 +23,7 @@ namespace ir {
bool operator==(Expr a, Expr b) {
if (a.get() == b.get()) return true;
IrEqualVisitor cmp;
return cmp.Compare(a, b);
return ir_utils::IRCompare(a, b);
}
bool operator!=(Expr a, Expr b) { return !(a == b); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment