Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -19,3 +19,9 @@ cinn_cc_test(test_ir_compare SRCS ir_compare_test.cc DEPS cinncore)
cinn_cc_test(test_ir_copy SRCS ir_copy_test.cc DEPS cinncore)
cinn_cc_test(test_schedule_block_graph SRCS schedule_block_graph_test.cc DEPS
cinncore)
if(WITH_CUDA)
cinn_cc_test(
test_static_shape_group_scheduler SRCS st_shape_group_scheduler_test.cc
DEPS cinncore decomposer_test_helper)
endif()
......@@ -19,6 +19,7 @@
namespace cinn {
namespace ir {
namespace ir_utils {
TEST(CollectIRNodes, basic0) {
Expr C = Expr(1) + 2;
......@@ -41,15 +42,15 @@ TEST(CollectIRNodes, basic) {
auto C = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
auto stages = CreateStages({C});
ast_gen_ius::TensorGroup tensor_group({C});
auto fn = Lower("fn", stages, {A, B, C});
auto fn = LowerToAst("fn", {A, B, C}, &tensor_group);
LOG(INFO) << "fn:\n" << fn;
auto tensors =
CollectIRNodes(fn, [](const Expr* x) { return x->as_tensor(); });
ASSERT_EQ(tensors.size(), 5UL);
ASSERT_EQ(tensors.size(), 3UL);
auto fn_body = fn.As<ir::_LoweredFunc_>()->body;
LOG(INFO) << "fn.body:\n" << fn_body;
......@@ -57,6 +58,6 @@ TEST(CollectIRNodes, basic) {
CollectIRNodes(fn_body, [](const Expr* x) { return x->as_tensor(); });
auto exprs = CollectIRNodes(fn_body, [](const Expr* x) { return x; });
}
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -23,7 +23,7 @@
namespace cinn {
namespace ir {
namespace ir_utils {
TEST(TestIrCompare, SingleFunction) {
Target target = common::DefaultHostTarget();
......@@ -128,20 +128,16 @@ TEST(TestIrCompare, SingleFunction) {
ASSERT_EQ(func2_str, utils::GetStreamCnt(funcs_2.front()));
ASSERT_EQ(func3_str, utils::GetStreamCnt(funcs_3.front()));
IrEqualVisitor compartor;
// they are different at the name of root ScheduleBlock
ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_2.front()));
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front()));
// compare with itself
ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_1.front()));
IrEqualVisitor compartor_allow_suffix_diff(true);
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_1.front()));
// they are euqal if allowing suffix of name different
ASSERT_TRUE(
compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_2.front()));
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front(), true));
ASSERT_FALSE(compartor.Compare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(
compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front(), true));
}
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -16,16 +16,17 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace optim {
namespace ir {
namespace ir_utils {
TEST(IrCopy, basic) {
Expr a(1.f);
auto aa = IRCopy(a);
LOG(INFO) << "aa " << aa;
}
} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include <gtest/gtest.h>
......
......@@ -18,12 +18,14 @@
#include "paddle/cinn/ir/op/ir_operators.h"
namespace cinn::ir {
namespace cinn {
namespace ir {
namespace ir_utils {
TEST(IrVerify, basic) {
Expr a(1);
Expr b(1);
IrVerify(a + b);
}
} // namespace cinn::ir
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -20,6 +20,8 @@
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace ir {
......@@ -38,7 +40,8 @@ IRSchedule MakeIRSchedule(frontend::Program* program) {
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
std::vector<LoweredFunc> lowered_funcs =
op_lowerer.Lower(graph->fusion_groups.front(), false, false);
......@@ -94,10 +97,11 @@ frontend::Program CreateReduceProgram() {
}
TEST(ScheduleBlockGraph, elementwise) {
Context::Global().ResetNameId();
frontend::Program program = CreateElementwiseProgram();
IRSchedule ir_sch = MakeIRSchedule(&program);
ScheduleBlockGraph sbg(ir_sch);
LOG(INFO) << GetIR(ir_sch);
ScheduleBlockGraph sbg(ir_sch);
LOG(INFO) << sbg.Visualize();
CHECK_EQ(sbg.BlockIdsInOrder().size(), 6);
CHECK_EQ(sbg.nodes().size(), 6);
......@@ -135,24 +139,73 @@ TEST(ScheduleBlockGraph, elementwise) {
#ifdef CINN_WITH_CUDA
TEST(ScheduleBlockGraph, reduce) {
if (FLAGS_cinn_new_group_scheduler) {
Context::Global().ResetNameId();
frontend::Program program = CreateReduceProgram();
IRSchedule ir_sch = MakeIRSchedule(&program);
ScheduleBlockGraph sbg(ir_sch);
LOG(INFO) << GetIR(ir_sch);
LOG(INFO) << sbg.Visualize();
CHECK_EQ(sbg.BlockIdsInOrder().size(), 8);
CHECK_EQ(sbg.nodes().size(), 8);
CHECK_EQ(sbg.BlockIdsInOrder().size(), 5);
CHECK_EQ(sbg.nodes().size(), 5);
ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_48__reduce_init");
ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_2__reduce_init");
CHECK(v_reduce_init);
CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0);
CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3);
ScheduleBlockNode* v = sbg.RetrieveNode("var_48");
ScheduleBlockNode* v = sbg.RetrieveNode("var_2");
CHECK(v);
CHECK_EQ(v->UpstreamNodes().size(), 5);
CHECK_EQ(v->UpstreamNodes().size(), 2);
CHECK_EQ(v->DownstreamNodes().size(), 2);
std::vector<std::string> reverse_dfs_topo_order_ids;
sbg.DFSTopoWalk(
[&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) {
reverse_dfs_topo_order_ids.push_back(node->id());
});
for (const std::string& id : reverse_dfs_topo_order_ids) {
LOG(INFO) << id;
}
CHECK_EQ(reverse_dfs_topo_order_ids.size(), 5);
std::vector<std::string> dfs_topo_order_ids;
sbg.DFSTopoWalk(
[&dfs_topo_order_ids](const ScheduleBlockNode* node) {
dfs_topo_order_ids.push_back(node->id());
},
false);
for (const std::string& id : dfs_topo_order_ids) {
LOG(INFO) << id;
}
CHECK_EQ(dfs_topo_order_ids.size(), 5);
}
}
TEST(ScheduleBlockGraph, arg_max) {
Context::Global().ResetNameId();
frontend::NetBuilder builder("net_builder");
auto x = builder.CreateInput(Float(32), {8, 16}, "X");
auto y = builder.Argmax(x, 0);
frontend::Program program = builder.Build();
IRSchedule ir_sch = MakeIRSchedule(&program);
LOG(INFO) << GetIR(ir_sch);
ScheduleBlockGraph sbg(ir_sch);
LOG(INFO) << sbg.Visualize();
CHECK_EQ(sbg.BlockIdsInOrder().size(), 3);
CHECK_EQ(sbg.nodes().size(), 3);
ScheduleBlockNode* v0_idx = sbg.RetrieveNode("var_0_index");
CHECK(v0_idx);
CHECK_EQ(v0_idx->UpstreamNodes().size(), 1);
CHECK_EQ(v0_idx->DownstreamNodes().size(), 1);
ScheduleBlockNode* v0 = sbg.RetrieveNode("var_0");
CHECK(v0);
CHECK_EQ(v0->UpstreamNodes().size(), 2);
CHECK_EQ(v0->DownstreamNodes().size(), 0);
std::vector<std::string> reverse_dfs_topo_order_ids;
sbg.DFSTopoWalk([&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) {
reverse_dfs_topo_order_ids.push_back(node->id());
......@@ -160,7 +213,7 @@ TEST(ScheduleBlockGraph, reduce) {
for (const std::string& id : reverse_dfs_topo_order_ids) {
LOG(INFO) << id;
}
CHECK_EQ(reverse_dfs_topo_order_ids.size(), 8);
CHECK_EQ(reverse_dfs_topo_order_ids.size(), 3);
std::vector<std::string> dfs_topo_order_ids;
sbg.DFSTopoWalk(
......@@ -171,7 +224,7 @@ TEST(ScheduleBlockGraph, reduce) {
for (const std::string& id : dfs_topo_order_ids) {
LOG(INFO) << id;
}
CHECK_EQ(dfs_topo_order_ids.size(), 8);
CHECK_EQ(dfs_topo_order_ids.size(), 3);
}
#endif
......
......@@ -19,9 +19,9 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/type_defs.h"
......@@ -95,7 +95,7 @@ std::vector<ir::LoweredFunc> LowerCompute(
IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs) {
std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body));
exprs.emplace_back(ir::ir_utils::IRCopy(func->body));
}
return ir::IRSchedule(ir::ModuleExpr(exprs));
}
......@@ -106,10 +106,11 @@ std::string SourceCodeGen(const ModuleExpr& module_expr,
const Target& target) {
auto exprs = module_expr.GetExprs();
CHECK_EQ(exprs.size(), lowered_funcs.size()) << "size of func is not euqal";
std::vector<ir::LoweredFunc> updated_funcs = optim::IRCopy(lowered_funcs);
std::vector<ir::LoweredFunc> updated_funcs =
ir::ir_utils::IRCopy(lowered_funcs);
Module::Builder builder("test_module", target);
for (auto i = 0; i < lowered_funcs.size(); ++i) {
updated_funcs[i]->body = optim::IRCopy(exprs.at(i));
updated_funcs[i]->body = ir::ir_utils::IRCopy(exprs.at(i));
builder.AddFunction(updated_funcs[i]);
}
auto module = builder.Build();
......@@ -778,6 +779,7 @@ TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}
#ifdef CINN_WITH_CUDA
TEST_F(TestScheduleDesc, StepKind_Bind) {
lowered_funcs = LowerCompute({32, 128}, target);
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
......@@ -793,6 +795,7 @@ TEST_F(TestScheduleDesc, StepKind_Bind) {
CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}
#endif
TEST_F(TestScheduleDesc, StepKind_Rfactor) {
Expr M(32);
......@@ -839,12 +842,14 @@ TEST_F(TestScheduleDesc, StepKind_MergeExprs) {
auto funcs_1 =
LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
ir::IRSchedule ir_sch = ir::IRSchedule(ir::ModuleExpr(
{optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
ir::IRSchedule ir_sch =
ir::IRSchedule(ir::ModuleExpr({ir::ir_utils::IRCopy(funcs_0[0]->body),
ir::ir_utils::IRCopy(funcs_0[0]->body)}));
ir_sch.MergeExprs();
trace.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {}));
ir::IRSchedule replay_sch = ir::IRSchedule(ir::ModuleExpr(
{optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
ir::IRSchedule replay_sch =
ir::IRSchedule(ir::ModuleExpr({ir::ir_utils::IRCopy(funcs_0[0]->body),
ir::ir_utils::IRCopy(funcs_0[0]->body)}));
trace.Replay(&replay_sch);
auto lhs_exprs = ir_sch.GetModule().GetExprs();
......
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/decomposer/test_helper.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace ir {
using frontend::NetBuilder;
using frontend::RunDecomposer;
void Compile(NetBuilder* net_builder) {
auto program = net_builder->Build();
auto target = common::DefaultTarget();
RunDecomposer(&program, target);
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
CHECK_EQ(graph->fusion_groups.size(), 1);
auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
for (auto& fusion_group : graph->fusion_groups) {
std::vector<ir::LoweredFunc> lowered_funcs =
op_lowerer.Lower(fusion_group,
/* apply_op_schedule = */ true,
/* apply_group_schedule = */ false);
CHECK_EQ(lowered_funcs.size(), 1);
VLOG(1) << "without group schedule, lowered_func: "
<< lowered_funcs.front();
FLAGS_cinn_new_group_scheduler = true;
lowered_funcs = op_lowerer.Lower(fusion_group,
/* apply_op_schedule = */ true,
/* apply_group_schedule = */ true);
CHECK_EQ(lowered_funcs.size(), 1);
VLOG(1) << "after group schedule, lowered_func: " << lowered_funcs.front();
}
}
void CheckAccuracy(NetBuilder* net_builder,
const std::vector<std::string>& input_names) {
FLAGS_cinn_new_group_scheduler = true;
auto program = net_builder->Build();
auto target = common::DefaultTarget();
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPasses(graph.get(),
{"OpFusionPass", "FusionMergePass"});
VLOG(1) << "Before CheckFusionAccuracyPass:\n"
<< graph->DebugGroupedGraph(std::unordered_set<std::string>{});
hlir::framework::ApplyPasses(
graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"});
VLOG(1) << "After CheckFusionAccuracyPass:\n"
<< graph->DebugGroupedGraph(std::unordered_set<std::string>{});
auto scope = BuildScope(target, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
for (size_t i = 0; i < input_names.size(); ++i) {
scope->Var<hlir::framework::Tensor>(input_names[i]);
auto tensor = scope->GetTensor(input_names[i]);
std::vector<float> vec;
frontend::InitRandomVector<float>(
&vec, tensor->shape().numel(), 0.0f, 1.0f);
frontend::CopyFromVector<float>(vec, tensor, target);
}
auto runtime_program = gc.Build();
runtime_program->Execute();
}
// Each unittest below tests a single reduce,
// these unittests are only used to observe the generated IR and debug.
// Accuracy testing is guaranteed by Python unittests named
// test_reduce_op_xxx.py.
TEST(GROUP_SCHEDULER, last_reduce_only_1) {
NetBuilder net_builder("last_reduce_only_1");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {128, 64, 32}, "A");
auto B = net_builder.ReduceSum(A, {2});
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, last_reduce_only_2) {
NetBuilder net_builder("last_reduce_only_2");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {1024}, "A");
auto B = net_builder.ReduceSum(A, {0});
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, last_reduce_only_3) {
NetBuilder net_builder("last_reduce_only_3");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {512, 256}, "A");
auto B = net_builder.ReduceSum(A, {1});
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, non_last_reduce_only_1) {
NetBuilder net_builder("non_last_reduce_only_1");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {10, 10, 10}, "A");
auto B = net_builder.ReduceSum(A, {0, 1}, /* keep_dim = */ true);
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, non_last_reduce_only_2) {
NetBuilder net_builder("non_last_reduce_only_2");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {64, 32, 16, 8, 4}, "A");
auto B = net_builder.ReduceSum(A, {1, 2, 3}, /* keep_dim = */ true);
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, shuffle_reduce_only_1) {
NetBuilder net_builder("shuffle_reduce_only_1");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {32, 32, 32, 32}, "A");
auto B = net_builder.ReduceSum(A, {0, 2, 3});
};
CreateModel();
Compile(&net_builder);
}
TEST(GROUP_SCHEDULER, shuffle_reduce_only_2) {
NetBuilder net_builder("shuffle_reduce_only_2");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {32, 64, 56, 56}, "A");
auto B = net_builder.ReduceSum(A, {0, 2, 3});
};
CreateModel();
Compile(&net_builder);
}
// Each of the following unittest tests a basic pattern composed of multiple
// basic op. And apply accuracy checks to ensure that the results of fusion
// groups and independently running each op are consistent.
TEST(GROUP_SCHEDULER, elementwise_1) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_1");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_2) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_2");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.Cast(C, "float16");
auto E = net_builder.Cast(C, "float16");
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_3) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_3");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.Cast(C, "float16");
auto E = net_builder.Cast(C, "float16");
auto F = net_builder.Cast(D, "float32");
auto G = net_builder.Cast(E, "float32");
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_4) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_4");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.Cast(C, "float16");
auto E = net_builder.Cast(C, "float16");
auto F = net_builder.Add(D, E);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_broadcast) {
NetBuilder net_builder("elementwise_broadcast");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {128}, "A");
auto B = net_builder.CreateInput(Float(32), {128}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.BroadcastTo(C, {128, 128});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_broadcast) {
NetBuilder net_builder("elementwise_double_broadcast");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {128}, "A");
auto B = net_builder.CreateInput(Float(32), {128}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.BroadcastTo(C, {128, 128});
auto E = net_builder.BroadcastTo(C, {128, 128});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, non_last_reduce_elementwise_1) {
int h = 128, w = 128;
NetBuilder net_builder("non_last_reduce_elementwise_1");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.ReduceSum(A, {0});
auto C = net_builder.Cast(B, "float16");
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, last_reduce_elementwise) {
NetBuilder net_builder("last_reduce_elementwise");
std::vector<std::string> input_names = {"A", "C"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {128, 64}, "A");
auto B = net_builder.ReduceSum(A, {1});
auto C = net_builder.CreateInput(Float(32), {128}, "C");
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_1) {
NetBuilder net_builder("keep_dim_reduce_elementwise");
std::vector<std::string> input_names = {"A", "C"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A");
auto B = net_builder.CreateInput(Float(32), {1, 64, 1, 1}, "B");
auto C = net_builder.ReduceSum(A, {0, 2, 3}, true);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_2) {
NetBuilder net_builder("keep_dim_reduce_elementwise_2");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A");
auto B = net_builder.CreateInput(Float(32), {16, 64, 1, 1}, "B");
auto C = net_builder.ReduceSum(A, {2, 3}, true);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_3) {
NetBuilder net_builder("keep_dim_reduce_elementwise_3");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A");
auto B = net_builder.CreateInput(Float(32), {16, 64, 1}, "B");
auto C = net_builder.ReduceSum(A, {2}, true);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_4) {
NetBuilder net_builder("keep_dim_reduce_elementwise_4");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A");
auto B = net_builder.CreateInput(Float(32), {16, 1, 2048}, "B");
auto C = net_builder.ReduceSum(A, {1}, true);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_5) {
NetBuilder net_builder("keep_dim_reduce_elementwise_5");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {16, 64, 16, 1024}, "A");
auto B = net_builder.CreateInput(Float(32), {16, 1, 16, 1}, "B");
auto C = net_builder.ReduceSum(A, {1, 3}, true);
auto D = net_builder.Add(B, C);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_non_last_reduce) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_non_last_reduce");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.ReduceSum(C, {0});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_last_reduce) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_last_reduce");
std::vector<std::string> input_names = {"A", "C"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.ReduceSum(C, {1});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_non_last_reduce_elementwise) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_non_last_reduce_elementwise");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto E = net_builder.ReduceSum(C, {0});
auto F = net_builder.Cast(E, "float16");
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_last_reduce_elementwise) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_non_last_reduce_elementwise");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto E = net_builder.ReduceSum(C, {1});
auto F = net_builder.Cast(E, "float16");
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_non_last_reduce_elementwise) {
int h = 128, w = 128;
NetBuilder net_builder("elementwise_double_non_last_reduce_elementwise");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h, w}, "B");
auto C = net_builder.Add(A, B);
auto E = net_builder.ReduceSum(C, {0});
auto F = net_builder.ReduceSum(C, {0});
auto G = net_builder.Add(E, F);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, double_non_last_reduce_elementwise) {
int h = 128, w = 128;
NetBuilder net_builder("double_non_last_reduce_elementwise");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h * 2, w}, "B");
auto E = net_builder.ReduceSum(A, {0});
auto F = net_builder.ReduceSum(B, {0});
auto G = net_builder.Add(E, F);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, triple_non_last_reduce) {
int h = 128, w = 1024;
NetBuilder net_builder("triple_non_last_reduce");
std::vector<std::string> input_names = {"A", "B"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {128, 1024}, "A");
auto B = net_builder.ReduceSum(A, {0});
auto C = net_builder.ReduceSum(A, {0});
auto D = net_builder.ReduceSum(A, {0});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_1) {
int h = 32, w = 32;
NetBuilder net_builder("reduce_broadcast_1");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h * w}, "A");
auto B = net_builder.ReduceSum(A, {0});
auto C = net_builder.BroadcastTo(B, {h * w}, {0});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_2) {
int h = 32, w = 32;
NetBuilder net_builder("reduce_broadcast_2");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.ReduceSum(A, {0, 1});
auto C = net_builder.BroadcastTo(B, {h, w}, {1});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_3) {
int h = 32, w = 32;
NetBuilder net_builder("reduce_broadcast_3");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A");
auto B = net_builder.ReduceSum(A, {1, 2});
auto C = net_builder.BroadcastTo(B, {h, h, w}, {0});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_reduce_broadcast) {
int h = 32, w = 32;
NetBuilder net_builder("reduce_broadcast_reduce_broadcast");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A");
auto B = net_builder.ReduceSum(A, {1, 2});
auto C = net_builder.BroadcastTo(B, {h, h, w}, {0});
auto D = net_builder.ReduceSum(C, {1, 2});
auto E = net_builder.BroadcastTo(D, {h, h, w}, {0});
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_elementwise) {
int h = 32, w = 32;
NetBuilder net_builder("reduce_broadcast_elementwise");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A");
auto B = net_builder.ReduceSum(A, {1, 2});
auto C = net_builder.BroadcastTo(B, {h, h, w}, {0});
auto D = net_builder.CreateInput(Float(32), {h, w}, "B");
auto E = net_builder.BroadcastTo(D, {h, h, w}, {1, 2});
auto F = net_builder.Add(C, E);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_1) {
NetBuilder net_builder("elementwise_double_reduce_elementwise_1");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {32, 32}, "A");
auto B = net_builder.CreateInput(Float(32), {32, 32}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.ReduceSum(C, {1}, false);
auto E = net_builder.ReduceSum(C, {1}, false);
auto F = net_builder.Add(D, E);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_2) {
NetBuilder net_builder("elementwise_double_reduce_elementwise_2");
std::vector<std::string> input_names = {"A"};
// create model
auto CreateModel = [&]() {
auto A = net_builder.CreateInput(Float(32), {1, 1000}, "A");
auto B = net_builder.CreateInput(Float(32), {1, 1000}, "B");
auto C = net_builder.Add(A, B);
auto D = net_builder.ReduceSum(C, {1}, false);
auto E = net_builder.ReduceSum(C, {1}, false);
auto F = net_builder.Add(D, E);
};
CreateModel();
Compile(&net_builder);
CreateModel();
CheckAccuracy(&net_builder, input_names);
}
// Each of following unittests tests a group composed of typical operators
// Builds a LayerNorm-style graph (mean/variance over the last axis, then
// normalize) and checks the scheduler handles the mixed reduce + broadcast
// + elementwise group correctly.
TEST(GROUP_SCHEDULER, layernorm) {
  int h = 32, w = 1024;
  NetBuilder net_builder("layernorm");
  std::vector<std::string> input_names = {"A"};
  // create model
  auto CreateModel = [&]() {
    // x
    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
    // x * x
    auto B = net_builder.Multiply(A, A);
    // sum x
    auto C = net_builder.ReduceSum(A, {1});
    // sum x*x
    auto D = net_builder.ReduceSum(B, {1});
    // constant w (the reduced width, used as the divisor for the means)
    auto E = net_builder.FillConstant<float>({h}, 1024.0f, "E");
    // mean
    auto F = net_builder.Divide(C, E);
    auto FF = net_builder.BroadcastTo(F, {h, w}, {0});
    // mean of x*x
    auto G = net_builder.Divide(D, E);
    // mean * mean
    auto H = net_builder.Multiply(F, F);
    // variance: E[x*x] - mean^2
    auto I = net_builder.Subtract(G, H);
    // eps
    auto J = net_builder.FillConstant<float>({h}, 1e-10f, "J");
    // variance + eps
    auto K = net_builder.Add(I, J);
    // standard deviation
    auto L = net_builder.Sqrt(K);
    auto LL = net_builder.BroadcastTo(L, {h, w}, {0});
    // x - mean
    auto M = net_builder.Subtract(A, FF);
    // normalize: (x - mean) / stddev
    auto N = net_builder.Divide(M, LL);
  };
  // Build once for Compile, rebuild for the accuracy run.
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
// Builds a numerically-stable softmax over the last axis:
// exp(x - max(x)) / sum(exp(x - max(x))), exercising max-reduce, sum-reduce
// and broadcasts in one group.
TEST(GROUP_SCHEDULER, softmax) {
  int h = 32, w = 1024;
  NetBuilder net_builder("softmax");
  std::vector<std::string> input_names = {"A"};
  // create model
  auto CreateModel = [&]() {
    // softmax input
    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
    // reduce max
    auto B = net_builder.ReduceMax(A, {1});
    // broadcast
    auto C = net_builder.BroadcastTo(B, {h, w}, {0});
    // x - max(x) — subtracting the row max keeps exp() from overflowing
    auto D = net_builder.Subtract(A, C);
    // exp(x)
    auto E = net_builder.Exp(D);
    // reduce sum
    auto F = net_builder.ReduceSum(E, {1});
    // broadcast
    auto G = net_builder.BroadcastTo(F, {h, w}, {0});
    // exp(x)/sum(exp(x))
    auto H = net_builder.Divide(E, G);
  };
  // Build once for Compile, rebuild for the accuracy run.
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
} // namespace ir
} // namespace cinn
......@@ -20,8 +20,8 @@
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
......
......@@ -3,10 +3,8 @@ core_gather_headers()
gather_srcs(
cinnapi_src
SRCS
ir_visitor.cc
ir_mutator.cc
ir_printer.cc
ir_verify.cc
ir_compare.cc
ir_nodes_collector.cc
ir_copy.cc)
ir_copy.cc
ir_replace.cc)
......@@ -17,16 +17,22 @@
#include <regex>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace ir {
namespace ir_utils {
bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
if (lhs.get() == rhs.get()) { // the same object, including both are null
return true;
}
if (only_compare_structure_ && !lhs.defined() && !rhs.defined()) {
return true;
}
if (!lhs.defined() || !rhs.defined()) { // someone invalid
return false;
VLOG(5) << "Not equal on Expr, someone not defined";
......@@ -44,10 +50,9 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return equal;
}
bool IrEqualVisitor::Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff) {
// if allow_name_suffix_diff=true then just compare the name prefix before the
bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs) {
// if allow_name_suffix_diff_=true then just compare the name prefix before
// the
// "_[0-9]+"
auto common_len = 0;
for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) {
......@@ -65,7 +70,7 @@ bool IrEqualVisitor::Compare(const std::string& lhs,
equal = true;
} else {
equal = false;
if (allow_name_suffix_diff) {
if (allow_name_suffix_diff_) {
equal = is_endswith_index(lhs) && is_endswith_index(rhs);
}
}
......@@ -179,17 +184,26 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) {
bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) {
auto* rhs = other->As<Call>();
return lhs->name == rhs->name && Compare(lhs->read_args, rhs->read_args) &&
bool flag = Compare(lhs->read_args, rhs->read_args) &&
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) && lhs->call_type == rhs->call_type;
Compare(lhs->attrs, rhs->attrs) &&
lhs->call_type == rhs->call_type;
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
// TODO(CtfGo): Compare `func` field
}
bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) {
auto* rhs = other->As<_Var_>();
return lhs->name == rhs->name &&
Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) && lhs->tag == rhs->tag;
bool flag = Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) &&
lhs->tag == rhs->tag;
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
}
bool IrEqualVisitor::Visit(const Load* lhs, const Expr* other) {
......@@ -219,19 +233,25 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) {
bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) {
auto* rhs = other->As<_Buffer_>();
return Compare(lhs->shape, rhs->shape) &&
Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name &&
lhs->scope == rhs->scope &&
Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor &&
lhs->target == rhs->target &&
bool flag =
Compare(lhs->shape, rhs->shape) && Compare(lhs->strides, rhs->strides) &&
lhs->scope == rhs->scope && Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
if (only_compare_structure_) {
return flag;
}
return flag && lhs->name == rhs->name;
}
bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) {
auto* rhs = other->As<_Tensor_>();
return lhs->name == rhs->name && Compare(lhs->shape, rhs->shape);
bool flag = Compare(lhs->shape, rhs->shape);
if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->name, rhs->name);
}
bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {
......@@ -280,10 +300,15 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {
bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) {
auto* rhs = other->As<_Module_>();
return lhs->name == rhs->name && lhs->target == rhs->target &&
Compare(lhs->buffers, rhs->buffers) &&
bool flag = Compare(lhs->buffers, rhs->buffers) &&
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);
if (only_compare_structure_) {
return flag;
}
return flag && lhs->name == rhs->name;
}
bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) {
......@@ -345,11 +370,16 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) {
bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) {
auto* rhs = other->As<ScheduleBlock>();
return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) &&
Compare(lhs->iter_vars, rhs->iter_vars) &&
bool flag = Compare(lhs->iter_vars, rhs->iter_vars) &&
Compare(lhs->read_buffers, rhs->read_buffers) &&
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->attrs, rhs->attrs) && Compare(lhs->body, rhs->body);
Compare(lhs->body, rhs->body);
if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->attrs, rhs->attrs) &&
Compare(lhs->name, rhs->name);
}
bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
......@@ -358,5 +388,18 @@ bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
Compare(lhs->schedule_block, rhs->schedule_block);
}
// Two symbolic dims are equal iff their name, symbol name and concrete size
// all match.
// NOTE(review): unlike the other visitors in this file, the name is compared
// even when only_compare_structure_ is set — confirm this is intentional.
bool IrEqualVisitor::Visit(const _Dim_* lhs, const Expr* other) {
  auto* rhs = other->As<_Dim_>();
  return lhs->name == rhs->name &&
         lhs->GetSymbolName() == rhs->GetSymbolName() &&
         lhs->GetRealDimSize() == rhs->GetRealDimSize();
}
// Convenience wrapper: structural equality of two IR trees using a fresh
// IrEqualVisitor. only_compare_structure stays at its default (false).
bool IRCompare(const Expr& lhs, const Expr& rhs, bool allow_name_suffix_diff) {
  IrEqualVisitor ir_equal_visitor(allow_name_suffix_diff);
  return ir_equal_visitor.Compare(lhs, rhs);
}
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -16,24 +16,25 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace cinn {
namespace ir {
namespace ir_utils {
// Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {}
explicit IrEqualVisitor(bool allow_name_suffix_diff = false,
bool only_compare_structure = false)
: allow_name_suffix_diff_(allow_name_suffix_diff),
only_compare_structure_(only_compare_structure) {}
// Return true if they are euqal, otherwise false;
bool Compare(const Expr& lhs, const Expr& rhs);
private:
bool Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff = false);
bool Compare(const std::string& lhs, const std::string& rhs);
bool Compare(const std::map<std::string, attr_t>& lhs,
const std::map<std::string, attr_t>& rhs);
template <typename T>
......@@ -45,7 +46,14 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
// whether allowing name suffix ends with "_[0-9]+" different
bool allow_name_suffix_diff_ = false;
// not compare name field of Expr
bool only_compare_structure_ = false;
};
bool IRCompare(const Expr& lhs,
const Expr& rhs,
bool allow_name_suffix_diff = false);
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -21,15 +21,15 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace optim {
using namespace ir; // NOLINT
namespace ir {
namespace ir_utils {
namespace {
struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
// Use maps to unify all the copied tensors and buffers.
std::map<std::string, ir::_Tensor_*> tensor_map;
......@@ -241,6 +241,7 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
std::vector<Expr> buffers;
std::vector<Expr> functions;
std::vector<Expr> submodules;
std::vector<Expr> predicates;
for (auto& expr : op->buffers) {
buffers.push_back(Visit(&expr));
......@@ -254,10 +255,15 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
submodules.push_back(Visit(&expr));
}
for (auto& expr : op->predicates) {
predicates.push_back(Visit(&expr));
}
auto res = ir::_Module_::Make(op->name, op->target);
res->buffers = buffers;
res->functions = functions;
res->submodules = submodules;
res->predicates = predicates;
return Expr(res);
}
......@@ -407,6 +413,10 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
Visit(&op->schedule_block));
}
Expr Visit(const ir::_Dim_* op) override {
return ir::_Dim_::Make(op->name, op->sym_dim);
}
#define __(x__) Expr Visit(const ir::intrinsics::x__* op);
INTRINSIC_KIND_FOR_EACH(__)
#undef __
......@@ -474,7 +484,7 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) {
return intrinsics::BuiltinIntrin::Make(
op->name, op->args, op->id, op->arg_nums, op->type());
}
} // namespace
Expr IRCopy(Expr x) {
IRCopyVisitor visitor;
auto copied = visitor.Visit(&x);
......@@ -507,6 +517,6 @@ std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x) {
}
return res;
}
} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -24,9 +24,8 @@ namespace cinn {
namespace ir {
class ModuleExpr;
} // namespace ir
namespace optim {
namespace ir_utils {
//! Shallow copy an expression.
Expr IRCopy(Expr x);
......@@ -39,5 +38,6 @@ ir::LoweredFunc IRCopy(const ir::LoweredFunc& x);
std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x);
} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -15,14 +15,14 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include <glog/logging.h>
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace ir {
namespace ir_utils {
namespace {
struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
using teller_t = std::function<bool(const Expr*)>;
using handler_t = std::function<void(const Expr*)>;
......@@ -207,5 +207,116 @@ std::set<Expr> CollectReferencedTensors(
return ts0;
}
// Walks the expression and returns, in first-use order, every var that is
// used before (or without) a definition. Definitions come from Let symbols,
// For loop vars and Reduce axes; uses come from _Var_ nodes and the tensor
// names referenced by Load/Store.
std::vector<std::string> CollectUndefinedVars(const Expr* e) {
  struct Mutator : public ir::IRMutator<const Expr*> {
    using ir::IRMutator<const Expr*>::Visit;
    std::vector<std::string> undefined_vars;  // result, in discovery order
    std::set<std::string> defined_vars;       // vars currently in scope
    std::set<std::string> used_vars;          // vars referenced so far

    // Registers a definition; redefinition or use-before-def is a hard error.
    void CollectVarDef(const std::string& var) {
      CHECK(!defined_vars.count(var))
          << "var " << var << " has been defined, please check";
      CHECK(!used_vars.count(var))
          << "var " << var << " is wrongly used before definition";
      defined_vars.insert(var);
    }

    // Drops a var when leaving its scope (e.g. the end of a For body).
    void ClearVar(const std::string& var) {
      defined_vars.erase(var);
      used_vars.erase(var);
    }

    // Records a use; a use with no prior definition is reported.
    void CollectVarUse(const std::string& var) {
      used_vars.insert(var);
      if (defined_vars.count(var) == 0) {
        undefined_vars.push_back(var);
      }
    }

    void Visit(const ir::Let* op, const Expr* expr) override {
      Expr symbol = op->symbol;
      auto var = symbol.as_var_ref();
      CHECK(var.defined());
      // NOTE(review): the Let symbol stays defined after its body is visited
      // (no ClearVar) — confirm Let scoping is intentionally flat here.
      CollectVarDef(var->name);
      auto* node = expr->As<ir::Let>();
      Visit(&node->body, &node->body);
    }

    void Visit(const ir::For* op, const Expr* expr) override {
      // The loop var scopes over min/extent/body and is removed afterwards.
      CollectVarDef(op->loop_var->name);
      auto* node = expr->As<ir::For>();
      Visit(&node->min, &node->min);
      Visit(&node->extent, &node->extent);
      Visit(&node->body, &node->body);
      ClearVar(op->loop_var->name);
    }

    void Visit(const ir::Load* op, const Expr* expr) override {
      // A load is a use of the tensor name plus its index expressions.
      auto tensor = op->tensor.as_tensor_ref();
      CollectVarUse(tensor->name);
      auto* node = expr->As<ir::Load>();
      for (auto& idx : node->indices) Visit(&idx, &idx);
    }

    void Visit(const ir::Store* op, const Expr* expr) override {
      // A store also counts as a *use* of the target tensor name.
      auto tensor = op->tensor.as_tensor_ref();
      CollectVarUse(tensor->name);
      auto* node = expr->As<ir::Store>();
      for (auto& idx : node->indices) Visit(&idx, &idx);
      Visit(&node->value, &node->value);
    }

    void Visit(const ir::_Var_* op, const Expr* expr) override {
      CollectVarUse(op->name);
      auto* node = expr->As<ir::_Var_>();
      if (node->lower_bound.defined()) {
        Visit(&node->lower_bound, &node->lower_bound);
      }
      if (node->upper_bound.defined()) {
        Visit(&node->upper_bound, &node->upper_bound);
      }
    }

    void Visit(const ir::Reduce* op, const Expr* expr) override {
      // NOTE(review): reduce axes are never cleared after the body — confirm
      // axis names cannot legally repeat across sibling Reduce nodes.
      for (auto& axis : op->reduce_axis) {
        CollectVarDef(axis->name);
      }
      auto* node = expr->As<ir::Reduce>();
      if (node->init.defined()) Visit(&node->init, &node->init);
      Visit(&node->body, &node->body);
    }
  };
  Mutator mutator;
  mutator.Visit(e, e);
  return mutator.undefined_vars;
}
// Collects names of tensors that must be writable: targets of Store nodes,
// plus tensors that are call nodes (presumably extern-call outputs written
// in place — verify against callers).
std::set<std::string> CollectTensorNeedsWrite(const Expr* e) {
  std::set<std::string> tensor_written;
  // Handler runs only on nodes accepted by the teller below, so the _Tensor_
  // branch here only ever sees call-node tensors.
  IrNodesCollector::handler_t handler = [&](const Expr* x) {
    if (x->As<ir::Store>()) {
      tensor_written.insert(
          x->As<ir::Store>()->tensor.As<ir::_Tensor_>()->name);
    }
    if (x->As<ir::_Tensor_>()) {
      tensor_written.insert(x->As<ir::_Tensor_>()->name);
    }
  };
  IrNodesCollector::teller_t teller = [](const Expr* x) {
    if (x->As<ir::Store>() && x->As<ir::Store>()->tensor.As<ir::_Tensor_>()) {
      return true;
    }
    if (x->As<ir::_Tensor_>() && x->As<ir::_Tensor_>()->is_call_node()) {
      return true;
    }
    return false;
  };
  // `false`: do not stop descending after a match (visit nested stores too).
  IrNodesCollector collector(std::move(teller), std::move(handler), false);
  collector.Visit(e);
  return tensor_written;
}
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -18,7 +18,7 @@
namespace cinn {
namespace ir {
namespace ir_utils {
/**
* Collect the IR Nodes(without duplication) in the expression.
*/
......@@ -65,5 +65,24 @@ std::map<std::string, Expr> CollectTensorMap(
return true;
});
/**
* Collect undefined vars in the scope.
*
* e.g.
*
* The expression:
* for i
* for j
* a[i, j] = b[i, j]
*
* here a, b are vars without definition
*/
std::vector<std::string> CollectUndefinedVars(const Expr* e);
/**
* Collect the Tensor Nodes which will be Writed by Store or Call Nodes
*/
std::set<std::string> CollectTensorNeedsWrite(const Expr* e);
} // namespace ir_utils
} // namespace ir
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/utils/ir_printer.h"
#include <algorithm>
#include <iomanip>
#include <limits>
#include <vector>
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace ir {
using common::bfloat16;
using common::float16;
// Render a single expression through the visitor, then flush the
// accumulated text to the attached output stream.
void IrPrinter::Print(const Expr &e) {
  IRVisitorRequireReImpl::Visit(&e);
  os_ << str_;
  str_.clear();
}
// Render an expression list separated by `splitter`, then flush the
// accumulated text. Delegates to the Visit(vector, splitter) overload, which
// implements the same splitter-between-elements loop.
void IrPrinter::Print(const std::vector<Expr> &exprs,
                      const std::string &splitter) {
  Visit(exprs, splitter);
  os_ << str_;
  str_.clear();
}
// Print a signed integer literal so the emitted text keeps its width:
// 64-bit gets an "ll" suffix, 32-bit is bare, 16/8-bit get an explicit
// C-style cast prefix. Any other width is a fatal error.
void IrPrinter::Visit(const IntImm *x) {
  if (x->type().is_int(64)) {
    str_ += std::to_string(x->value);
    str_ += "ll";
  } else if (x->type().is_int(32)) {
    str_ += std::to_string(x->value);
  } else if (x->type().is_int(16)) {
    str_ += "(int16_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_int(8)) {
    str_ += "(int8_t)";
    str_ += std::to_string(x->value);
  } else {
    LOG(FATAL) << "Not support int type: " << x->type();
  }
}
// Print an unsigned integer literal with the same suffix/cast scheme as
// IntImm ("ull" for 64-bit, casts for 16/8-bit); a 1-bit uint is printed as
// a bool literal.
void IrPrinter::Visit(const UIntImm *x) {
  if (x->type().is_uint(64)) {
    str_ += std::to_string(x->value);
    str_ += "ull";
  } else if (x->type().is_uint(32)) {
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(16)) {
    str_ += "(uint16_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(8)) {
    str_ += "(uint8_t)";
    str_ += std::to_string(x->value);
  } else if (x->type().is_uint(1)) {
    if (x->value) {
      str_ += "true";
    } else {
      str_ += "false";
    }
  } else {
    LOG(FATAL) << "Not support uint type: " << x->type();
  }
}
// Print a floating-point literal that can be pasted back into generated
// source. fp16/bf16 inf/nan are emitted as raw-bit helper calls; finite
// values are cast and printed with max_digits10 so they round-trip exactly.
void IrPrinter::Visit(const FloatImm *x) {
  std::ostringstream ss;
  if (x->type().is_float16()) {
    if (std::isinf(x->value)) {
      ss << "cinn::common::raw_uint16_to_float16(0x7c00)";
    } else if (std::isnan(x->value)) {
      ss << "cinn::common::raw_uint16_to_float16(0x7e00)";
    } else {
      ss << "(float16)";
      ss << std::setprecision(std::numeric_limits<float16>::max_digits10);
      ss << static_cast<float16>(x->value) << "f";
    }
  } else if (x->type().is_bfloat16()) {
    if (std::isinf(x->value)) {
      ss << "cinn::common::raw_uint16_to_bfloat16(0x7F80)";
    } else if (std::isnan(x->value)) {
      ss << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)";
    } else {
      ss << "(bfloat16)";
      ss << std::setprecision(std::numeric_limits<bfloat16>::max_digits10);
      ss << static_cast<bfloat16>(x->value) << "f";
    }
  } else if (x->type().is_float(32)) {
    ss << std::setprecision(std::numeric_limits<float>::max_digits10);
    // showpoint keeps a decimal point so integral values read as floats.
    ss << std::showpoint;
    ss << x->value;
    // inf/nan get no "f" suffix — they are not valid C literals anyway.
    if (std::isfinite(x->value)) {
      ss << "f";
    }
  } else if (x->type().is_float(64)) {
    ss << std::setprecision(std::numeric_limits<double>::max_digits10);
    ss << std::showpoint;
    ss << x->value;
  } else {
    LOG(FATAL) << "Not support float type: " << x->type();
  }
  str_ += ss.str();
}
// Print a string literal wrapped in double quotes.
// NOTE(review): the value is emitted verbatim — embedded quotes or escapes
// are not escaped; confirm callers never rely on round-tripping such text.
void IrPrinter::Visit(const StringImm *x) {
  str_ += '"';
  str_ += x->value;
  str_ += '"';
}
void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); }
void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); }
void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); }
void IrPrinter::Visit(const Div *x) { PrintBinaryOp("/", x); }
void IrPrinter::Visit(const Mod *x) { PrintBinaryOp("%", x); }
void IrPrinter::Visit(const EQ *x) { PrintBinaryOp("==", x); }
void IrPrinter::Visit(const NE *x) { PrintBinaryOp("!=", x); }
void IrPrinter::Visit(const LT *x) { PrintBinaryOp("<", x); }
void IrPrinter::Visit(const LE *x) { PrintBinaryOp("<=", x); }
void IrPrinter::Visit(const GT *x) { PrintBinaryOp(">", x); }
void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); }
void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); }
void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); }
void IrPrinter::Visit(const Not *x) {
str_ += "!";
Visit(x->v());
}
void IrPrinter::Visit(const Min *x) {
str_ += "cinn_min(";
Visit(x->a());
str_ += ", ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Max *x) {
str_ += "cinn_max(";
Visit(x->a());
str_ += ", ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Minus *x) {
str_ += "-(";
Visit(x->v());
str_ += ")";
}
// Print a loop: a schedule-annotation prefix (parallel/unroll/vectorize/
// thread_bind/serial/default), then "(loop_var, min, extent)" and the body
// on the next line.
void IrPrinter::Visit(const For *x) {
  if (x->is_parallel()) {
    str_ += "parallel for (";
  } else if (x->is_unrolled()) {
    str_ += "unroll for (";
  } else if (x->is_vectorized()) {
    // vectorize[<factor>] for (...)
    int factor = x->vectorize_info().factor;
    str_ += "vectorize[";
    str_ += std::to_string(factor);
    str_ += "] for (";
  } else if (x->is_binded()) {
    auto &bind_info = x->bind_info();
    if (bind_info.valid()) {
      // offset 0/1/2 maps to axis x/y/z of the block/thread index.
      char axis_name = 'x' + bind_info.offset;
      auto for_type = bind_info.for_type;
      std::string prefix =
          for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx.";
      str_ += "thread_bind[";
      str_ += prefix;
      str_ += axis_name;
      str_ += "] for (";
    } else {
      str_ += "thread_bind[invalid info] for (";
    }
  } else if (x->is_serial()) {
    str_ += "serial for (";
  } else if (x->is_default()) {
    str_ += "default for (";
  } else {
    str_ += "for (";
  }
  Visit(x->loop_var);
  str_ += ", ";
  Visit(x->min);
  str_ += ", ";
  Visit(x->extent);
  str_ += ")\n";
  DoIndent();
  Visit(x->body);
}
void IrPrinter::Visit(const PolyFor *x) {
if (x->is_parallel()) {
str_ += "parallel poly_for (";
} else {
str_ += "poly_for (";
}
Visit(x->iterator);
str_ += ", ";
Visit(x->init);
str_ += ", ";
Visit(x->condition);
str_ += ", ";
Visit(x->inc);
str_ += ")\n";
DoIndent();
Visit(x->body);
}
// Print: if (cond) { true_case } [ else { false_case } ]
void IrPrinter::Visit(const IfThenElse *x) {
  str_ += "if (";
  Visit(x->condition);
  str_ += ") {\n";
  IncIndent();
  DoIndent();
  Visit(x->true_case);
  // DecIndent before the newline here vs. after it in the else branch below
  // yields identical text, since DoIndent runs after both.
  DecIndent();
  str_ += "\n";
  DoIndent();
  str_ += "}";
  if (x->false_case.defined()) {
    str_ += " else {\n";
    IncIndent();
    DoIndent();
    Visit(x->false_case);
    str_ += "\n";
    DecIndent();
    DoIndent();
    str_ += "}";
  }
}
// Print a statement block: "{", each statement indented on its own line,
// then a closing "}" at the original indent level.
void IrPrinter::Visit(const Block *x) {
  str_ += "{\n";
  IncIndent();
  for (std::size_t i = 0; i < x->stmts.size(); ++i) {
    if (i) str_ += "\n";  // newline between statements, not after the last
    DoIndent();
    Visit(x->stmts[i]);
  }
  DecIndent();
  str_ += "\n";
  DoIndent();
  str_ += "}";
}
// Print a call as name(read_args..., write_args...): write args follow the
// read args in the same comma-separated argument list.
void IrPrinter::Visit(const Call *x) {
  str_ += x->name;
  str_ += "(";
  if (!x->read_args.empty()) {
    for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) {
      Visit(x->read_args[i]);
      str_ += ", ";
    }
    Visit(x->read_args.back());
  }
  if (!x->write_args.empty()) {
    // Separate read and write groups only when both are present.
    if (!x->read_args.empty()) str_ += ", ";
    for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) {
      Visit(x->write_args[i]);
      str_ += ", ";
    }
    Visit(x->write_args.back());
  }
  str_ += ")";
}
void IrPrinter::Visit(const Cast *x) {
str_ += x->type().to_string();
str_ += "(";
Visit(x->v());
str_ += ")";
}
void IrPrinter::Visit(const _Module_ *x) {}
void IrPrinter::Visit(const _Var_ *x) { str_ += x->name; }
void IrPrinter::Visit(const Alloc *x) {
auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer);
str_ += "alloc(";
str_ += buffer->name;
str_ += ", ";
Visit(x->extents);
str_ += ")";
}
void IrPrinter::Visit(const Select *x) {
str_ += "select(";
Visit(x->condition);
str_ += ", ";
Visit(x->true_value);
str_ += ", ";
Visit(x->false_value);
str_ += ")";
}
// Print a load as tensor[i0, i1, ...]; the address may be a tensor (print
// its name) or a scalar expression (print it recursively).
void IrPrinter::Visit(const Load *x) {
  if (x->is_addr_tensor()) {
    auto *tensor = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor);
    str_ += tensor->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "]";
}
// Print a store as tensor[i0, i1, ...] = value, mirroring the Load format.
void IrPrinter::Visit(const Store *x) {
  if (x->is_addr_tensor()) {
    auto *tensor_node = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor_node);
    str_ += tensor_node->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "] = ";
  Visit(x->value);
}
void IrPrinter::Visit(const Free *x) {
auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer);
str_ += "free(";
str_ += buffer->name;
str_ += ")";
}
void IrPrinter::DoIndent() { str_ += std::string(indent_, ' '); }
void IrPrinter::IncIndent() { indent_ += indent_unit; }
void IrPrinter::DecIndent() { indent_ -= indent_unit; }
void IrPrinter::Visit(const _Buffer_ *x) {
std::vector<std::string> dim_names;
std::transform(x->shape.begin(),
x->shape.end(),
std::back_inserter(dim_names),
[&](const Expr &x) { return utils::GetStreamCnt(x); });
str_ += "_Buffer_<";
str_ += x->type().to_string();
str_ += ": ";
str_ += utils::Join(dim_names, ",");
str_ += ">(";
str_ += x->name;
str_ += ")";
}
// Print a tensor as Tensor(name, [d0,d1,...]).
void IrPrinter::Visit(const _Tensor_ *x) {
  str_ += "Tensor(";
  str_ += x->name;
  str_ += ", [";
  for (std::size_t i = 0; i < x->shape.size(); ++i) {
    if (i) str_ += ",";  // comma between dims, not after the last
    Visit(x->shape[i]);
  }
  str_ += "])";
}
// Print a lowered function header "function name (a, b, ...)" followed by
// its body on the next line.
void IrPrinter::Visit(const _LoweredFunc_ *f) {
  str_ += "function ";
  str_ += f->name;
  str_ += " (";
  std::vector<std::string> arg_names;
  arg_names.reserve(f->args.size());
  for (auto &arg : f->args) {
    arg_names.push_back(arg.name());
  }
  str_ += utils::Join(arg_names, ", ");
  str_ += ")\n";
  Visit(f->body);
}
void IrPrinter::Visit(const Let *f) {
CHECK(f->type().valid());
str_ += f->type().to_string();
str_ += " ";
Visit(f->symbol);
if (f->body.defined()) {
str_ += " = ";
Visit(f->body);
}
}
// Print a reduction as Reduce(<op>, <body>,<init>).
// NOTE(review): the op labels mix case styles ("sum"/"sub" vs "Div"/"Mul"/
// "Max"/"Min") — kept as-is, since tests may compare printed IR text.
void IrPrinter::Visit(const Reduce *f) {
  str_ += "Reduce(";
  switch (f->reduce_type) {
    case Reduce::ReduceType::kSum:
      str_ += "sum";
      break;
    case Reduce::ReduceType::kSub:
      str_ += "sub";
      break;
    case Reduce::ReduceType::kDiv:
      str_ += "Div";
      break;
    case Reduce::ReduceType::kMul:
      str_ += "Mul";
      break;
    case Reduce::ReduceType::kMax:
      str_ += "Max";
      break;
    case Reduce::ReduceType::kMin:
      str_ += "Min";
      break;
    case Reduce::ReduceType::kAll:
      str_ += "&&";
      break;
    case Reduce::ReduceType::kAny:
      str_ += "||";
      break;
  }
  str_ += ", ";
  Visit(f->body);
  str_ += ",";
  Visit(f->init);
  str_ += ")";
}
void IrPrinter::Visit(const Ramp *x) {
str_ += "Ramp(";
Visit(x->base);
str_ += ",";
Visit(x->stride);
str_ += ",";
str_ += std::to_string(x->lanes);
str_ += ")";
}
void IrPrinter::Visit(const Broadcast *x) {
str_ += "Broadcast(";
Visit(x->value);
str_ += ",";
str_ += std::to_string(x->lanes);
str_ += ")";
}
void IrPrinter::Visit(const FracOp *x) {
str_ += "(";
Visit(x->a());
str_ += " / ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Product *x) {
str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Visit(x->operand(i));
str_ += " * ";
}
if (!x->operands().empty()) Visit(x->operands().back());
str_ += ")";
}
void IrPrinter::Visit(const Sum *x) {
str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Visit(x->operand(i));
str_ += " + ";
}
if (!x->operands().empty()) Visit(x->operands().back());
str_ += ")";
}
void IrPrinter::Visit(const PrimitiveNode *x) {
str_ += x->name;
str_ += "(";
std::vector<std::string> args_repr;
for (auto &args : x->arguments) {
std::vector<std::string> arg_repr;
for (auto &arg : args) {
arg_repr.push_back(utils::GetStreamCnt(arg));
}
args_repr.push_back(utils::Join(arg_repr, ","));
}
str_ += utils::Join(args_repr, ",");
str_ += ")";
}
void IrPrinter::Visit(const _BufferRange_ *x) {
auto *buffer = x->buffer.As<ir::_Buffer_>();
CHECK(buffer);
str_ += buffer->name;
str_ += "[";
for (std::size_t i = 0; i < x->ranges.size(); i++) {
if (i) str_ += ", ";
auto &range = x->ranges[i];
str_ += range->name;
str_ += "(";
if (range->lower_bound.defined()) {
Visit(range->lower_bound);
str_ += ":";
} else {
str_ += "undefined:";
}
if (range->upper_bound.defined()) {
Visit(range->upper_bound);
} else {
str_ += "undefined";
}
str_ += ")";
}
str_ += "]";
}
void IrPrinter::Visit(const ScheduleBlock *x) {}
// Print a realized schedule block:
//   ScheduleBlock(name)
//   {
//     i0, i1 = axis.bind(v0, v1)
//     read_buffers(...) / write_buffers(...) / attrs(...)   (if present)
//     <body>
//   }
// Note: ScheduleBlock itself is printed entirely here; its own Visit above
// is intentionally empty.
void IrPrinter::Visit(const ScheduleBlockRealize *x) {
  auto *schedule_block = x->schedule_block.As<ScheduleBlock>();
  str_ += "ScheduleBlock(";
  str_ += schedule_block->name;
  str_ += ")\n";
  DoIndent();
  str_ += "{\n";
  // print block vars and bindings
  // NOTE(review): both vectors are copied here; const& bindings would do.
  auto iter_vars = schedule_block->iter_vars;
  auto iter_values = x->iter_values;
  CHECK_EQ(iter_vars.size(), iter_values.size());
  IncIndent();
  if (!iter_vars.empty()) DoIndent();
  for (std::size_t i = 0; i < iter_vars.size(); i++) {
    if (i) str_ += ", ";
    str_ += iter_vars[i]->name;
  }
  if (!iter_vars.empty()) str_ += " = axis.bind(";
  for (std::size_t i = 0; i < iter_values.size(); i++) {
    if (i) str_ += ", ";
    Visit(iter_values[i]);
  }
  if (!iter_vars.empty()) str_ += ")\n";
  // print block body
  if (!schedule_block->read_buffers.empty()) {
    DoIndent();
    str_ += "read_buffers(";
    auto &read_buffers = schedule_block->read_buffers;
    for (std::size_t i = 0; i < read_buffers.size(); i++) {
      if (i) str_ += ", ";
      Visit(read_buffers[i]);
    }
    str_ += ")\n";
  }
  if (!schedule_block->write_buffers.empty()) {
    DoIndent();
    str_ += "write_buffers(";
    auto &write_buffers = schedule_block->write_buffers;
    for (std::size_t i = 0; i < write_buffers.size(); i++) {
      if (i) str_ += ", ";
      Visit(write_buffers[i]);
    }
    str_ += ")\n";
  }
  if (!schedule_block->attrs.empty()) {
    DoIndent();
    str_ += "attrs(";
    bool comma = false;
    for (auto &&kv : schedule_block->attrs) {
      if (comma) str_ += ", ";
      str_ += kv.first;
      str_ += ":";
      // attr values are a variant; stream whichever alternative is held.
      absl::visit(
          [this](auto &&arg) {
            std::ostringstream ss;
            ss << arg;
            this->str_ += ss.str();
          },
          kv.second);
      comma = true;
    }
    str_ += ")\n";
  }
  DoIndent();
  Visit(schedule_block->body);
  str_ += "\n";
  DecIndent();
  DoIndent();
  str_ += "}";
}
void IrPrinter::Visit(const IntrinsicOp *x) {
switch (x->getKind()) {
#define __(op__) \
case IntrinsicKind::k##op__: \
Visit(llvm::dyn_cast<intrinsics::op__>(x)); \
break;
INTRINSIC_KIND_FOR_EACH(__)
#undef __
}
}
// Emit the buffer_get_data_handle intrinsic applied to the buffer argument.
// NOTE(review): unlike PodValueToX below, no "(" is emitted before the
// argument yet a ")" is appended — verify the intrinsic name constant
// already contains the opening parenthesis (same applies to the
// const-handle overload below).
void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_handle;
  Visit(x->buffer);
  str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) {
str_ += runtime::intrinsic::buffer_get_data_const_handle;
Visit(x->buffer);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::PodValueToX *x) {
str_ += "pod_value_to_";
str_ += x->GetOutputType(0).to_string();
str_ += "(";
Visit(x->pod_value_ptr);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BufferCreate *x) {
str_ += runtime::intrinsic::buffer_create;
str_ += "()";
}
void IrPrinter::Visit(const intrinsics::GetAddr *x) {
str_ += "get_addr(";
Visit(x->data);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) {
str_ += runtime::intrinsic::args_construct_repr;
str_ += "(";
Visit(std::vector<Expr>(x->args.begin(), x->args.end()));
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) {
str_ += runtime::intrinsic::builtin_intrin_repr;
str_ += "_";
str_ += x->name;
str_ += "(";
if (!x->args.empty()) {
for (std::size_t i = 0; i + 1 < x->args.size(); i++) {
Visit(x->args[i]);
str_ += ", ";
}
Visit(x->args.back());
}
str_ += ")";
}
// Stream an expression by rendering it into a temporary buffer first, so a
// partially-failed print never leaves half-written output on `os`.
std::ostream &operator<<(std::ostream &os, Expr a) {
  std::ostringstream ss;
  IrPrinter printer(ss);
  printer.Print(a);
  return os << ss.str();
}
std::ostream &operator<<(std::ostream &os, const std::vector<Expr> &a) {
std::stringstream ss;
IrPrinter printer(ss);
printer.Print(a);
os << ss.str();
return os;
}
std::ostream &operator<<(std::ostream &os, const ir::Module &m) {
os << "Module " << m->name << " {\n\n";
for (auto &fn : m->functions) {
os << fn << '\n';
}
os << "\n\n}";
return os;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace lang {
class LoweredFunc;
} // namespace lang
namespace ir {
class Module;
//! Pretty-printer for CINN IR: renders Expr trees and intrinsics as text.
//! Output is accumulated in `str_` by the Visit overloads and emitted to the
//! wrapped std::ostream (see the Print implementations in the .cc file).
struct IrPrinter : public IRVisitorRequireReImpl<void> {
  explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}
  //! Emit an expression on the output stream.
  void Print(const Expr &e);
  //! Emit an expression list separated by `splitter` (default ", ").
  void Print(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ");
  //! Emit a binary operator as `(a op b)`.
  template <typename IRN>
  void PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x);
  //! Prefix the current line with `indent_` spaces.
  void DoIndent();
  //! Increase the indent size.
  void IncIndent();
  //! Decrease the indent size.
  void DecIndent();
  //! Access the underlying output stream.
  std::ostream &os() { return os_; }
  void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
  //! Visit each expression in order, appending `splitter` between elements
  //! (but not after the last one); a no-op on an empty vector.
  void Visit(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ") {
    for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
      Visit(exprs[i]);
      str_ += splitter;
    }
    if (!exprs.empty()) Visit(exprs.back());
  }
// Declare one Visit override per IR node type.
#define __(op__) void Visit(const op__ *x) override;
  NODETY_FORALL(__)
#undef __
// Declare one virtual Visit per intrinsic kind (overridable by subclasses).
#define __(op__) virtual void Visit(const intrinsics::op__ *x);
  INTRINSIC_KIND_FOR_EACH(__)
#undef __
 protected:
  //! Accumulated textual output of the most recent visitation.
  std::string str_;
 private:
  //! Destination stream the rendered text is written to.
  std::ostream &os_;
  //! Current indent depth, in units of `indent_unit`.
  uint16_t indent_{};
  //! Number of spaces per indent level.
  const int indent_unit{2};
};
std::ostream &operator<<(std::ostream &os, Expr a);
std::ostream &operator<<(std::ostream &os, const std::vector<Expr> &a);
std::ostream &operator<<(std::ostream &os, const Module &m);
// Renders a binary IR node in fully parenthesized infix form: `(a op b)`.
template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op,
                              const BinaryOpNode<IRN> *x) {
  str_ += "(";
  Visit(x->a());
  str_ += " " + op + " ";
  Visit(x->b());
  str_ += ")";
}
} // namespace ir
} // namespace cinn
......@@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include <set>
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace optim {
namespace ir {
namespace ir_utils {
using utils::GetStreamCnt;
namespace {
......@@ -42,14 +43,14 @@ struct IrReplaceMutator : ir::IRMutator<Expr*> {
// Replace a variable occurrence: when the node kinds match and the printed
// form of *expr equals the printed form of the source pattern (from_repr_),
// overwrite *expr with a deep copy of the replacement expression.
// NOTE(review): matching by textual representation (GetStreamCnt) assumes two
// distinct vars never print identically — confirm against var naming rules.
void Visit(const ir::_Var_* op, Expr* expr) override {
  if (op->node_type() == from_->node_type() &&
      from_repr_ == GetStreamCnt(*expr)) {
    // IRCopy gives each replacement site its own node, avoiding shared
    // subtrees between distinct occurrences.
    *expr = ir::ir_utils::IRCopy(to_);
  }
}
// Replace a broadcast occurrence using the same kind-plus-printed-form match
// as the _Var_ overload; the match target is substituted with a deep copy of
// the replacement expression.
void Visit(const ir::Broadcast* op, Expr* expr) override {
  if (op->node_type() == from_->node_type() &&
      from_repr_ == GetStreamCnt(*expr)) {
    // Deep-copy so each replaced occurrence owns an independent subtree.
    *expr = ir::ir_utils::IRCopy(to_);
  }
}
......@@ -65,5 +66,6 @@ void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to) {
IrReplaceMutator(from, to)(expr);
}
} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment