Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <numeric> #include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/program_builder.h" #include "test/cpp/cinn/program_builder.h"
namespace cinn { namespace cinn {
......
...@@ -26,10 +26,11 @@ ...@@ -26,10 +26,11 @@
#include "paddle/cinn/common/target.h" #include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
...@@ -49,7 +50,12 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ...@@ -49,7 +50,12 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor. // Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor( for (const ir::Var& iter_var : sche_block->iter_vars) {
if (iter_var->is_reduce_axis) {
return false;
}
}
std::set<ir::Expr> find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }); compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
...@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ...@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
return false; return false;
} }
// the xxx_reduce_init block cannot be inlined.
if (ir::IsReduceInitTensorName(tensor->name)) {
return false;
}
// Skip external calls
std::vector<ir::Expr> consumers =
ir::GetConsumers(sche_block_realize_expr, root);
for (const ir::Expr& consumer : consumers) {
std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
consumer.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body,
[&](const ir::Expr* x) {
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
tensor->name;
});
if (find_load.empty()) {
return false;
}
}
// write_buffers.size() = 1 and read_buffers is empty, means const // write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer // we can inline to consumer
if (sche_block->read_buffers.empty()) { if (sche_block->read_buffers.empty()) {
...@@ -76,17 +105,19 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ...@@ -76,17 +105,19 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
} }
// Check this schedule block is the only writer of the tensor. // Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { find_store =
return x->As<ir::Store>() && ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name; return x->As<ir::Store>() &&
}); (x->As<ir::Store>()->tensor).as_tensor_ref()->name ==
tensor->name;
});
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
} }
// Check there is no overlap between the buffers the schedule block reads and // Check there is no overlap between the buffers the schedule block reads and
// writes. // writes.
std::set<ir::Expr> find_load = std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr; return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
}); });
if (!find_load.empty()) { if (!find_load.empty()) {
......
...@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule { ...@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule {
std::vector<SearchState> ApplyOnBlock(SearchState state, std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override; const std::string& block_name) override;
private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT
private: private:
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
...@@ -30,9 +31,9 @@ ...@@ -30,9 +31,9 @@
#include "paddle/cinn/ir/function_base.h" #include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h" #include "paddle/cinn/poly/stage.h"
...@@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C"); {M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C");
poly::StageMap stages = CreateStages({A, B, C}); ast_gen_ius::TensorGroup tensor_group({A, B, C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline", lang::LowerToAstVec("TestAutoInline_SingleLoopInline",
stages,
{A, C}, {A, C},
{}, &tensor_Group,
{}, target);
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:"; VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body; VLOG(6) << funcs[0]->body;
...@@ -161,14 +159,14 @@ TEST(AutoInline, AddReluInline) { ...@@ -161,14 +159,14 @@ TEST(AutoInline, AddReluInline) {
"inferdtype"); "inferdtype");
const auto& shape_dict = graph->GetAttrs< const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>( auto op_lowerer =
dtype_dict, shape_dict, target); hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL); EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
op_lowerer->Lower(graph->fusion_groups[0], op_lowerer.Lower(graph->fusion_groups[0],
/*apply_op_schedule = */ false, /*apply_op_schedule = */ false,
/*apply_group_schedule=*/false); /*apply_group_schedule=*/false);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body; VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
#include <cstdlib> #include <cstdlib>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
...@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
return false; return false;
}; };
auto find_target_exprs = ir::CollectIRNodesWithoutTensor( auto find_target_exprs = ir::ir_utils::CollectIRNodesWithoutTensor(
schedule_block->body, schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) { [&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x); return has_reduce_iter(x) || has_nonserial_loop(x);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
...@@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) { ...@@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) {
#else #else
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
auto stages = CreateStages({C}); ast_gen_ius::TensorGroup tensor_group({C});
auto funcs = cinn::lang::LowerVec( auto funcs =
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); cinn::lang::LowerToAstVec("test_init", {A, B, C}, &tensor_group, target);
auto ast_expr = funcs[0]->body; auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h" #include "test/cpp/cinn/program_builder.h"
namespace cinn { namespace cinn {
......
...@@ -29,11 +29,11 @@ ...@@ -29,11 +29,11 @@
#include "paddle/cinn/ir/buffer.h" #include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
......
...@@ -21,15 +21,16 @@ ...@@ -21,15 +21,16 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h" #include "paddle/cinn/poly/stage.h"
...@@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) { ...@@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs = lang::LowerToAstVec(
lang::LowerVec("TestMultiLevelTile_SimpleLoops", "TestMultiLevelTile_SimpleLoops", {C}, &tensor_group, target);
stages,
{C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
...@@ -261,7 +255,7 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -261,7 +255,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
{ {
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))) i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{ {
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
} }
} }
} }
...@@ -308,10 +302,10 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -308,10 +302,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer) ScheduleBlock(temp_matmul_out_local_temp_buffer)
{ {
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))) i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)]) read_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)], _X[i0_0(0:32), i2(0:32)], _Y[i2(0:32), i1_0(0:32)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)]) write_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)])
{ {
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))])) temp_matmul_out_local_temp_buffer[i0_0, i1_0] = (temp_matmul_out_local_temp_buffer[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
} }
} }
} }
...@@ -453,7 +447,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -453,7 +447,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{ {
i0, i1, i2, i3 = axis.bind(i, j, k, a) i0, i1, i2, i3 = axis.bind(i, j, k, a)
{ {
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) pad_temp_0[i0, i1, i2, i3] = select(((i3 < (1 + 16)) and ((i3 >= 1) and ((i2 < (1 + 16)) and (i2 >= 1)))), input[i0, i1, (i2 - 1), (i3 - 1)], -3.40282347e+38f)
} }
} }
} }
...@@ -477,7 +471,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -477,7 +471,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{ {
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{ {
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f var_0__reduce_init[i0_0, i1_0, i2_0, i3_0] = -3.40282347e+38f
} }
} }
} }
...@@ -511,10 +505,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -511,10 +505,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
ScheduleBlock(var_0_local_temp_buffer) ScheduleBlock(var_0_local_temp_buffer)
{ {
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) read_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)], _pad_temp_0[i0_1(0:2), i1_1(0:8)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) write_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)])
{ {
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1] = cinn_max(var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1], pad_temp_0_shared_temp_buffer[i0_1, i1_1, ((2 * i2_1) + i4), ((2 * i3_1) + i5)])
} }
} }
} }
...@@ -533,7 +527,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -533,7 +527,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{ {
ScheduleBlock(var_0) ScheduleBlock(var_0)
{ {
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((i_0_j_0_k_0_a_0_fused % 4) + (4 * ((i_j_k_a_fused / 2) % 2))) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{ {
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
namespace cinn {
namespace auto_schedule {
bool ReductionFactoring::CanApply(const std::string& block_name,
ir::IRSchedule* ir_schedule) const {
ir::Expr block_expr = ir_schedule->GetBlock(block_name);
ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
ir::ScheduleBlock* sch_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(sch_block);
AnalyzeScheduleBlockReadWriteBuffer(sch_block);
// 1. The block must have write buffer
if (sch_block->write_buffers.empty()) {
return false;
}
// 2. The block must have at least one reduce axis
const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
bool find_reduce_axis = false;
for (int i = 0; i < iter_vars.size(); ++i) {
if (iter_vars[i]->is_reduce_axis) {
find_reduce_axis = true;
break;
}
}
if (!find_reduce_axis) {
return false;
}
// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std::vector<ir::Expr> loops = ir_schedule->GetLoops(block_name);
for (const ir::Expr& loop : loops) {
const ir::Expr& body = loop.As<ir::For>()->body;
if (body.As<ir::Block>()) {
if (body.As<ir::Block>()->stmts.size() == 1) {
if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else if (body.As<ir::Block>()->stmts.size() == 2) {
if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr ||
!ir::IsReduceInitTensorName(
GetBlockName(body.As<ir::Block>()->stmts[0]))) {
return false;
}
if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else {
return false;
}
} else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
continue;
} else {
return false;
}
}
return true;
}
RuleApplyType ReductionFactoring::AnalyseApplyType(
    SearchState state, const std::string& block_name) const {
  // The rule applies exactly when CanApply accepts the block in this state.
  if (this->CanApply(block_name, &(state->ir_schedule))) {
    return RuleApplyType::kApply;
  }
  return RuleApplyType::kCannotApply;
}
std::vector<SearchState> ReductionFactoring::ApplyOnBlock(
    SearchState state, const std::string& block_name) {
  // Work on a copy so the input state remains untouched.
  SearchState next_state = state.Copy();
  this->Apply(block_name, &(next_state->ir_schedule));
  std::vector<SearchState> results;
  results.push_back(next_state);
  return results;
}
// Applies the ReductionFactoring transformation to the block `block_name`:
// reorders loops so that all spatial loops precede all reduction loops,
// fuses the reduction loops into one, splits it, factorizes the reduction
// into a separate "_rf" block, and finally fuses the rfactor blocks back and
// (on NVGPU targets) binds the reduction loop to threadIdx.x with a shared
// buffer for cross-thread reduction. Mutates `ir_schedule` in place.
void ReductionFactoring::Apply(const std::string& block_name,
                               ir::IRSchedule* ir_schedule) {
  ir::Expr block = ir_schedule->GetBlock(block_name);
  std::vector<ir::Expr> all_loops = ir_schedule->GetLoops(block_name);
  std::vector<ir::Expr> new_loop_order;
  size_t num_spatial_loops = 0;
  size_t num_reduction_loops = 0;
  // 1. Add all spatial loops (loops whose loop var is not a reduce var)
  std::unordered_set<std::string> reduce_loop_var_names =
      GetReduceLoopVarNames(block);
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) == 0) {
      new_loop_order.push_back(expr);
      ++num_spatial_loops;
    }
  }
  // 2. Add all reduction loops after the spatial ones
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) > 0) {
      new_loop_order.push_back(expr);
      ++num_reduction_loops;
    }
  }
  // Nothing to factorize if there is no reduction loop.
  if (num_reduction_loops == 0) {
    return;
  }
  // 3. Reorder if new_loop_order differs from the original order
  CHECK_EQ(all_loops.size(), new_loop_order.size());
  for (int i = 0; i < all_loops.size(); ++i) {
    if (all_loops[i].As<ir::For>()->loop_var->name !=
        new_loop_order[i].As<ir::For>()->loop_var->name) {
      ir_schedule->Reorder(new_loop_order);
      break;
    }
  }
  // 4. Fuse all reduction loops into a single loop
  ir::Expr fused_reduce_loop;
  VLOG(6) << "before Fuse: " << ir_schedule->GetModule().GetExprs()[0];
  if (num_reduction_loops > 1) {
    // After the reorder, the reduction loops occupy the trailing indices.
    std::vector<int> reduction_loop_indices;
    for (int i = num_spatial_loops; i < all_loops.size(); ++i) {
      reduction_loop_indices.push_back(i);
    }
    CHECK_EQ(reduction_loop_indices.size(), num_reduction_loops);
    fused_reduce_loop = ir_schedule->Fuse(block_name, reduction_loop_indices);
  } else {
    // A single reduction loop needs no fusing; it is the last loop.
    all_loops = ir_schedule->GetLoops(block_name);
    fused_reduce_loop = all_loops.back();
  }
  // 5. Split the reduction loop into 2 parts: the largest divisor of the
  //    extent that is <= max_factor becomes the outer (factorized) loop.
  VLOG(6) << "before Split: " << ir_schedule->GetModule().GetExprs()[0];
  int factor = 1;
  int max_factor = 1024;
  int extent = ir::GetLoopExtent(fused_reduce_loop);
  for (int i = max_factor; i >= 1; --i) {
    if (extent % i == 0) {
      factor = i;
      break;
    }
  }
  std::vector<cinn::ir::Expr> splited_reduction_loops =
      ir_schedule->Split(fused_reduce_loop, {factor, -1});
  // 6. Apply FactorizeReduction on the outer split loop, producing the
  //    "_rf" rfactor block after the spatial loops.
  VLOG(6) << "before FactorizeReduction: "
          << ir_schedule->GetModule().GetExprs()[0];
  ir_schedule->FactorizeReduction(splited_reduction_loops[0],
                                  num_spatial_loops);
  VLOG(6) << "after FactorizeReduction: "
          << ir_schedule->GetModule().GetExprs()[0];
  // 7. Loop fusion and cross thread reduction: move the rf block and its
  //    reduce_init next to the final reduction, then bind to threadIdx.x
  //    with a shared-memory buffer when targeting NVGPU.
  std::vector<ir::Expr> rb_loops = ir_schedule->GetLoops(block_name);
  ir::Expr rf_block = ir_schedule->GetBlock(block_name + "_rf");
  ir_schedule->SimpleComputeAt(rf_block, rb_loops.back());
  rb_loops = ir_schedule->GetLoops(block_name);
  ir::Expr rf_init_block =
      ir_schedule->GetBlock(block_name + "_rf__reduce_init");
  ir_schedule->SimpleComputeAt(rf_init_block, rb_loops.back());
  if (*target_ == common::DefaultNVGPUTarget()) {
    rb_loops = ir_schedule->GetLoops(block_name);
    rf_block = ir_schedule->GetBlock(block_name + "_rf");
    ir_schedule->Bind(rb_loops.back(), "threadIdx.x");
    ir_schedule->SetBuffer(rf_block, "shared");
  }
  VLOG(6) << "Loop fusion and cross thread reduction: "
          << ir_schedule->GetModule().GetExprs()[0];
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// Auto-gen rule that factorizes a reduction: it reorders loops so spatial
// loops come first, then splits the fused reduction loop and extracts an
// rfactor ("_rf") block, enabling cross-thread reduction on GPU targets.
class ReductionFactoring : public AutoGenRule {
 public:
  explicit ReductionFactoring(const common::Target& target)
      : AutoGenRule(target) {}
  ~ReductionFactoring() = default;

  // In the future, we will no longer use this interface.
  // This rule is only applied per-block, so whole-schedule Init never applies.
  RuleApplyType Init(ir::IRSchedule* init_schedule) override {
    return RuleApplyType::kCannotApply;
  }

  // In the future, we will no longer use this interface.
  void Apply(int index) override {
    LOG(FATAL) << "This is a deprecated interface, please do not use it.";
    return;
  }

  // Returns kApply when CanApply accepts `block_name` in `state`,
  // kCannotApply otherwise.
  RuleApplyType AnalyseApplyType(SearchState state,
                                 const std::string& block_name) const override;

  std::string GetRuleName() const override { return "ReductionFactoring"; }

  // Applies the rule to `block_name` on a copy of `state`; returns the
  // single resulting state.
  std::vector<SearchState> ApplyOnBlock(SearchState state,
                                        const std::string& block_name) override;

  // Performs the factorization on `ir_schedule` in place.
  void Apply(const std::string& block_name, ir::IRSchedule* ir_schedule);

 private:
  // Checks the preconditions: the block writes a buffer, has a reduce axis,
  // and its loop nest is a simple chain (a reduce_init block is tolerated).
  bool CanApply(const std::string& block_name,
                ir::IRSchedule* ir_schedule) const;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/concrete_program_builder.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace auto_schedule {
class TestReductionFactoring : public TestAutoGenRuleBase {
 public:
  std::vector<std::string> default_input_names = {"X"};
  std::vector<std::string> default_output_names = {"out"};

  // Builds a reduce program over `shape` reducing along `reduce_dim`,
  // applies the ReductionFactoring rule on the schedule block named
  // `block_name`, and verifies the transformed IR matches `expected_ir`
  // verbatim.
  void TestApplyOnReduce(const std::vector<int>& shape,
                         const std::vector<int>& reduce_dim,
                         const std::string& block_name,
                         const std::string& expected_ir) {
    Initialize(common::DefaultNVGPUTarget());
    // In order to forcibly use the most basic Compute of reduction
    FLAGS_cinn_new_group_scheduler = 1;
    auto test_program = tests::ReduceBuilder().Build(
        {{"X", shape}}, {{"reduce_dim", reduce_dim}});
    // construct input parameter
    ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
    SearchState state(ir_schedule, 0, {});
    std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
    ASSERT_EQ(func_bodys.size(), 1UL);
    VLOG(6) << "Original Expr:\n" << func_bodys[0];

    // apply
    ReductionFactoring reduction_factoring(target_);
    ASSERT_EQ(reduction_factoring.AnalyseApplyType(state, block_name),
              RuleApplyType::kApply);
    auto result = reduction_factoring.ApplyOnBlock(state, block_name)[0];
    std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
    EXPECT_EQ(exprs.size(), 1UL);
    std::stringstream ir;
    ir << exprs[0];
    VLOG(6) << "ReductionFactoring applied Expr: " << exprs[0];

    // check
    // NOTE(review): this inspects the pre-transformation `ir_schedule`, not
    // `result->ir_schedule` -- confirm whether the transformed schedule's
    // block count was intended here instead.
    const std::vector<ir::Expr>& blocks = ir_schedule.GetAllBlocks();
    // Use gtest assertions rather than glog CHECK_* so a mismatch fails this
    // test case instead of aborting the entire test binary, and stays
    // consistent with the EXPECT_EQ used above.
    ASSERT_EQ(blocks.size(), 2UL);
    EXPECT_EQ(ir.str(), expected_ir);
  }
};
// ReductionFactoring must report kCannotApply for a schedule block that
// performs no reduction (here: a plain elementwise add).
TEST_F(TestReductionFactoring, AnalyseApplyType) {
  Context::Global().ResetNameId();
  Initialize(common::DefaultNVGPUTarget());
  auto program =
      tests::OpBuilder("elementwise_add").Build({{"X", {4, 5}}, {"Y", {4, 5}}});
  ir::IRSchedule schedule = MakeIRSchedule(program);
  VLOG(6) << "Original Expr:\n" << schedule.GetModule().GetExprs()[0];
  SearchState state(schedule, 0, {});
  ReductionFactoring rule(target_);
  EXPECT_EQ(rule.AnalyseApplyType(state, "var_1"),
            RuleApplyType::kCannotApply);
}
#ifdef CINN_WITH_CUDA
// Golden-string test: factoring the single reduce dim (size 64) of a
// [32, 64] input. The rule introduces an rfactor tensor `var_0_rf` and binds
// the reduce axis to threadIdx.x; the result is compared verbatim.
TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) {
Context::Global().ResetNameId();
// Expected IR after ReductionFactoring; matched byte-for-byte in
// TestApplyOnReduce, so this literal must not be reformatted.
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f
}
{
serial for (reduce_k_0_1, 0, 1)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1)
var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0])
}
}
}
}
}
}
}
})";
// Input shape {32, 64}, reduce over axis 1, target block "var_0".
TestApplyOnReduce({32, 64}, {1}, "var_0", expected_ir);
}
// Golden-string test: factoring two reduce dims (64 x 128) of a
// [32, 64, 128] input. The two reduce loops are fused into a single 1024-wide
// threadIdx.x-bound loop feeding the rfactor tensor; compared verbatim.
TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) {
Context::Global().ResetNameId();
// Expected IR after ReductionFactoring; matched byte-for-byte in
// TestApplyOnReduce, so this literal must not be reformatted.
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused])
}
}
}
}
}
}
}
})";
// Input shape {32, 64, 128}, reduce over axes {1, 2}, target block "var_0".
TestApplyOnReduce({32, 64, 128}, {1, 2}, "var_0", expected_ir);
}
// Golden-string test: factoring three reduce dims (64 x 64 x 64) of a
// [32, 64, 64, 64] input. All three reduce loops are fused into one
// 1024-wide threadIdx.x-bound loop with a 256-iteration serial inner loop;
// compared verbatim.
TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) {
Context::Global().ResetNameId();
// Expected IR after ReductionFactoring; matched byte-for-byte in
// TestApplyOnReduce, so this literal must not be reformatted.
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused])
}
}
}
}
}
}
}
})";
// Input shape {32, 64, 64, 64}, reduce over axes {1, 2, 3}, block "var_0".
TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir);
}
#endif
} // namespace auto_schedule
} // namespace cinn
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
...@@ -52,9 +53,9 @@ TEST(SkipRule, Basic) { ...@@ -52,9 +53,9 @@ TEST(SkipRule, Basic) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec( std::vector<ir::LoweredFunc> funcs =
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
...@@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ...@@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec( std::vector<ir::LoweredFunc> funcs =
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
......
...@@ -61,12 +61,14 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule( ...@@ -61,12 +61,14 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
"inferdtype"); "inferdtype");
auto& shape_dict = graph->GetMutableAttrs< auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target_);
lowered_funcs_ = lowered_funcs_ =
op_lowerer.Lower(graph->fusion_groups.front(), op_lowerer.Lower(graph->fusion_groups.front(),
/*apply_op_schedule = */ apply_manual_schedule, /*apply_op_schedule = */ apply_manual_schedule,
/*apply_group_schedule = */ apply_manual_schedule); /*apply_group_schedule = */ apply_manual_schedule,
/*apply_pass = */ apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
std::vector<Expr> bodys; std::vector<Expr> bodys;
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
DECLARE_bool(auto_schedule_use_cost_model); PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
......
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
#include <vector> #include <vector>
#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/functional.h" #include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
...@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs, ...@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
// compare exprs size firstly // compare exprs size firstly
if (lhs_exprs.size() != rhs_exprs.size()) return false; if (lhs_exprs.size() != rhs_exprs.size()) return false;
// compare every expr one by one with ir::IrEqualVisitor // compare every expr one by one with ir::ir_utils::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) { for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor( if (!ir::ir_utils::IRCompare(lhs_exprs[i], rhs_exprs[i], true))
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name return false;
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
} }
return true; return true;
} }
......
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
#include "paddle/cinn/common/object.h" #include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h" #include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_compare.h" #include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
...@@ -70,8 +70,8 @@ struct SearchStateHash { ...@@ -70,8 +70,8 @@ struct SearchStateHash {
size_t operator()(const SearchState& s) const; size_t operator()(const SearchState& s) const;
}; };
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST // SearchStateHash equal functor, use ir::ir_utils::IrEqualVisitor to compare
// struct and fields // their AST struct and fields
struct SearchStateEqual { struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const; bool operator()(const SearchState& lhs, const SearchState& rhs) const;
}; };
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h" #include "paddle/cinn/common/context.h"
...@@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) { ...@@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) {
ir::Tensor C = lang::Compute( ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
ast_gen_ius::TensorGroup const_group_1({A, B});
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const", auto a_plus_const_funcs_1 =
poly::CreateStages({A, B}), lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_1, target);
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const", ast_gen_ius::TensorGroup const_group_2({A, B});
poly::CreateStages({A, B}), auto a_plus_const_funcs_2 =
{A, B}, lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_2, target);
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_b_funcs = lang::LowerVec("A_plus_B", ast_gen_ius::TensorGroup plus_group({A, C});
poly::CreateStages({A, C}), auto a_plus_b_funcs =
{A, C}, lang::LowerToAstVec("A_plus_B", {A, C}, &plus_group, target);
{},
{},
nullptr,
target,
true);
std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B) std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
{ {
......
...@@ -4,5 +4,6 @@ core_gather_headers() ...@@ -4,5 +4,6 @@ core_gather_headers()
gather_srcs(cinnapi_src SRCS evolutionary_search.cc) gather_srcs(cinnapi_src SRCS evolutionary_search.cc)
cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS # TODO(zhhsplendid): enable this test again
cinncore test_program_builder) #cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
# cinncore test_program_builder)
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include "paddle/cinn/utils/sized_multi_set.h" #include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
DECLARE_bool(auto_schedule_use_cost_model); PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
...@@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase( ...@@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) { for (auto&& record : records) {
ir::IRSchedule ir_sch( ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr), ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_)); utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch); ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost)); results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
...@@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1, ...@@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
for (size_t i = 0; i < father_exprs.size(); ++i) { for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) { if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
cross_over_exprs.push_back(optim::IRCopy(father_exprs[i])); cross_over_exprs.push_back(ir::ir_utils::IRCopy(father_exprs[i]));
} else { } else {
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i])); cross_over_exprs.push_back(ir::ir_utils::IRCopy(mother_exprs[i]));
} }
} }
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
...@@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate( ...@@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate(
// ir_schedule // ir_schedule
const auto& task_key = tune_task_.serialized_key; const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch( ir::IRSchedule pir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr), ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed)); utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true); new_trace.Replay(&pir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_); ApplyPostScheduleRules(&pir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch)); auto res = SearchState(std::move(pir_sch));
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6)); "EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment