Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -22,7 +22,7 @@
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
......
......@@ -26,10 +26,11 @@
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
......@@ -49,7 +50,12 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
for (const ir::Var& iter_var : sche_block->iter_vars) {
if (iter_var->is_reduce_axis) {
return false;
}
}
std::set<ir::Expr> find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
return false;
......@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
return false;
}
// the xxx_reduce_init block cannot be inlined.
if (ir::IsReduceInitTensorName(tensor->name)) {
return false;
}
// Skip external calls
std::vector<ir::Expr> consumers =
ir::GetConsumers(sche_block_realize_expr, root);
for (const ir::Expr& consumer : consumers) {
std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
consumer.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body,
[&](const ir::Expr* x) {
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
tensor->name;
});
if (find_load.empty()) {
return false;
}
}
// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
if (sche_block->read_buffers.empty()) {
......@@ -76,17 +105,19 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
}
// Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
});
find_store =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name ==
tensor->name;
});
if (find_store.size() != 1UL) {
return false;
}
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) {
......
......@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule {
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT
private:
......
......@@ -21,6 +21,7 @@
#include <iostream>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
......@@ -30,9 +31,9 @@
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
......@@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) {
ir::Tensor C = Compute(
{M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C");
poly::StageMap stages = CreateStages({A, B, C});
ast_gen_ius::TensorGroup tensor_group({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline",
stages,
{A, C},
{},
{},
nullptr,
target,
true);
lang::LowerToAstVec("TestAutoInline_SingleLoopInline",
{A, C},
&tensor_group,
target);
VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body;
......@@ -161,14 +159,14 @@ TEST(AutoInline, AddReluInline) {
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs =
op_lowerer->Lower(graph->fusion_groups[0],
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false);
op_lowerer.Lower(graph->fusion_groups[0],
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
......
......@@ -18,10 +18,10 @@
#include <cstdlib>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
......@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
return false;
};
auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
auto find_target_exprs = ir::ir_utils::CollectIRNodesWithoutTensor(
schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/lower.h"
......@@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) {
#else
Target target = common::DefaultHostTarget();
#endif
auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec(
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
ast_gen_ius::TensorGroup tensor_group({C});
auto funcs =
cinn::lang::LowerToAstVec("test_init", {A, B, C}, &tensor_group, target);
auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
......
......@@ -20,8 +20,8 @@
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
......
......@@ -29,11 +29,11 @@
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
......
......@@ -21,15 +21,16 @@
#include <iostream>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
......@@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) {
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_SimpleLoops",
stages,
{C},
{},
{},
nullptr,
target,
true);
ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerToAstVec(
"TestMultiLevelTile_SimpleLoops", {C}, &tensor_group, target);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
......@@ -261,7 +255,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f
temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
}
}
}
......@@ -308,10 +302,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
read_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)], _X[i0_0(0:32), i2(0:32)], _Y[i2(0:32), i1_0(0:32)])
write_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
temp_matmul_out_local_temp_buffer[i0_0, i1_0] = (temp_matmul_out_local_temp_buffer[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
}
}
}
......@@ -453,7 +447,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
pad_temp_0[i0, i1, i2, i3] = select(((i3 < (1 + 16)) and ((i3 >= 1) and ((i2 < (1 + 16)) and (i2 >= 1)))), input[i0, i1, (i2 - 1), (i3 - 1)], -3.40282347e+38f)
}
}
}
......@@ -477,7 +471,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
var_0__reduce_init[i0_0, i1_0, i2_0, i3_0] = -3.40282347e+38f
}
}
}
......@@ -511,10 +505,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
ScheduleBlock(var_0_local_temp_buffer)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
read_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)], _pad_temp_0[i0_1(0:2), i1_1(0:8)])
write_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1] = cinn_max(var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1], pad_temp_0_shared_temp_buffer[i0_1, i1_1, ((2 * i2_1) + i4), ((2 * i3_1) + i5)])
}
}
}
......@@ -533,7 +527,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(var_0)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((i_0_j_0_k_0_a_0_fused % 4) + (4 * ((i_j_k_a_fused / 2) % 2))) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
namespace cinn {
namespace auto_schedule {
bool ReductionFactoring::CanApply(const std::string& block_name,
ir::IRSchedule* ir_schedule) const {
ir::Expr block_expr = ir_schedule->GetBlock(block_name);
ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
ir::ScheduleBlock* sch_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(sch_block);
AnalyzeScheduleBlockReadWriteBuffer(sch_block);
// 1. The block must have write buffer
if (sch_block->write_buffers.empty()) {
return false;
}
// 2. The block must have at least one reduce axis
const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
bool find_reduce_axis = false;
for (int i = 0; i < iter_vars.size(); ++i) {
if (iter_vars[i]->is_reduce_axis) {
find_reduce_axis = true;
break;
}
}
if (!find_reduce_axis) {
return false;
}
// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std::vector<ir::Expr> loops = ir_schedule->GetLoops(block_name);
for (const ir::Expr& loop : loops) {
const ir::Expr& body = loop.As<ir::For>()->body;
if (body.As<ir::Block>()) {
if (body.As<ir::Block>()->stmts.size() == 1) {
if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else if (body.As<ir::Block>()->stmts.size() == 2) {
if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr ||
!ir::IsReduceInitTensorName(
GetBlockName(body.As<ir::Block>()->stmts[0]))) {
return false;
}
if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else {
return false;
}
} else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
continue;
} else {
return false;
}
}
return true;
}
// Reports kApply iff the rule's preconditions hold for `block_name`.
RuleApplyType ReductionFactoring::AnalyseApplyType(
    SearchState state, const std::string& block_name) const {
  if (CanApply(block_name, &(state->ir_schedule))) {
    return RuleApplyType::kApply;
  }
  return RuleApplyType::kCannotApply;
}
// Applies the rule to a copy of `state`, leaving the input state untouched,
// and returns the single resulting state.
std::vector<SearchState> ReductionFactoring::ApplyOnBlock(
    SearchState state, const std::string& block_name) {
  SearchState mutated_state = state.Copy();
  this->Apply(block_name, &(mutated_state->ir_schedule));
  std::vector<SearchState> results;
  results.push_back(mutated_state);
  return results;
}
// Performs the reduction-factoring transformation on the block named
// `block_name`: reorders loops so spatial loops precede reduction loops,
// fuses all reduction loops into one, splits the fused loop in two,
// applies FactorizeReduction on the outer split, fuses the resulting
// "_rf" blocks back via SimpleComputeAt, and on NVGPU targets binds the
// reduction loop to threadIdx.x with the rf buffer placed in shared memory.
// Mutates `ir_schedule` in place; statement order below is significant
// because each schedule primitive invalidates previously fetched exprs.
void ReductionFactoring::Apply(const std::string& block_name,
                               ir::IRSchedule* ir_schedule) {
  ir::Expr block = ir_schedule->GetBlock(block_name);
  std::vector<ir::Expr> all_loops = ir_schedule->GetLoops(block_name);
  std::vector<ir::Expr> new_loop_order;
  size_t num_spatial_loops = 0;
  size_t num_reduction_loops = 0;
  // 1. Add all spatial loops
  std::unordered_set<std::string> reduce_loop_var_names =
      GetReduceLoopVarNames(block);
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) == 0) {
      new_loop_order.push_back(expr);
      ++num_spatial_loops;
    }
  }
  // 2. Add all reduction loops
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) > 0) {
      new_loop_order.push_back(expr);
      ++num_reduction_loops;
    }
  }
  // Nothing to factorize without a reduction loop.
  if (num_reduction_loops == 0) {
    return;
  }
  // 3. Reorder if new_loop_order differs from the original order
  CHECK_EQ(all_loops.size(), new_loop_order.size());
  for (int i = 0; i < all_loops.size(); ++i) {
    if (all_loops[i].As<ir::For>()->loop_var->name !=
        new_loop_order[i].As<ir::For>()->loop_var->name) {
      // At least one loop is out of place; one Reorder fixes them all.
      ir_schedule->Reorder(new_loop_order);
      break;
    }
  }
  // 4. Fuse all reduction loops
  ir::Expr fused_reduce_loop;
  VLOG(6) << "before Fuse: " << ir_schedule->GetModule().GetExprs()[0];
  if (num_reduction_loops > 1) {
    // After the reorder, reduction loops occupy the trailing indices
    // [num_spatial_loops, all_loops.size()).
    std::vector<int> reduction_loop_indices;
    for (int i = num_spatial_loops; i < all_loops.size(); ++i) {
      reduction_loop_indices.push_back(i);
    }
    CHECK_EQ(reduction_loop_indices.size(), num_reduction_loops);
    fused_reduce_loop = ir_schedule->Fuse(block_name, reduction_loop_indices);
  } else {
    // Single reduction loop: re-fetch loops (Reorder may have replaced
    // the exprs) and take the innermost one.
    all_loops = ir_schedule->GetLoops(block_name);
    fused_reduce_loop = all_loops.back();
  }
  // 5. Split the reduction loop into 2 part
  VLOG(6) << "before Split: " << ir_schedule->GetModule().GetExprs()[0];
  int factor = 1;
  int max_factor = 1024;
  int extent = ir::GetLoopExtent(fused_reduce_loop);
  // Pick the largest divisor of the extent not exceeding max_factor
  // (1024, presumably the CUDA thread-block size limit — confirm).
  for (int i = max_factor; i >= 1; --i) {
    if (extent % i == 0) {
      factor = i;
      break;
    }
  }
  // {factor, -1}: outer loop has `factor` iterations, inner gets the rest.
  std::vector<cinn::ir::Expr> splited_reduction_loops =
      ir_schedule->Split(fused_reduce_loop, {factor, -1});
  // 6. Apply FactorizeReduction
  VLOG(6) << "before FactorizeReduction: "
          << ir_schedule->GetModule().GetExprs()[0];
  ir_schedule->FactorizeReduction(splited_reduction_loops[0],
                                  num_spatial_loops);
  VLOG(6) << "after FactorizeReduction: "
          << ir_schedule->GetModule().GetExprs()[0];
  // 7. Loop fusion and cross thread reduction
  // FactorizeReduction creates "<name>_rf" and "<name>_rf__reduce_init"
  // blocks; move both under the final reduction loop of `block_name`.
  std::vector<ir::Expr> rb_loops = ir_schedule->GetLoops(block_name);
  ir::Expr rf_block = ir_schedule->GetBlock(block_name + "_rf");
  ir_schedule->SimpleComputeAt(rf_block, rb_loops.back());
  rb_loops = ir_schedule->GetLoops(block_name);
  ir::Expr rf_init_block =
      ir_schedule->GetBlock(block_name + "_rf__reduce_init");
  ir_schedule->SimpleComputeAt(rf_init_block, rb_loops.back());
  if (*target_ == common::DefaultNVGPUTarget()) {
    // On NVGPU: make the reduction a cross-thread reduction and keep the
    // partial results in shared memory.
    rb_loops = ir_schedule->GetLoops(block_name);
    rf_block = ir_schedule->GetBlock(block_name + "_rf");
    ir_schedule->Bind(rb_loops.back(), "threadIdx.x");
    ir_schedule->SetBuffer(rf_block, "shared");
  }
  VLOG(6) << "Loop fusion and cross thread reduction: "
          << ir_schedule->GetModule().GetExprs()[0];
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// AutoGenRule that factorizes a reduction: the implementation splits the
// (fused) reduction loop and introduces an "_rf" block of partial results
// followed by a final reduction over them; on NVGPU targets the partial
// reduction is bound to threadIdx.x (see reduction_factoring.cc).
class ReductionFactoring : public AutoGenRule {
 public:
  explicit ReductionFactoring(const common::Target& target)
      : AutoGenRule(target) {}
  ~ReductionFactoring() = default;

  // In the future, we will no longer use this interface.
  // Whole-schedule init is unsupported: this rule is applied per block.
  RuleApplyType Init(ir::IRSchedule* init_schedule) override {
    return RuleApplyType::kCannotApply;
  }

  // In the future, we will no longer use this interface.
  void Apply(int index) override {
    LOG(FATAL) << "This is a deprecated interface, please do not use it.";
    return;
  }

  // Returns kApply when CanApply() accepts `block_name`, else kCannotApply.
  RuleApplyType AnalyseApplyType(SearchState state,
                                 const std::string& block_name) const override;

  std::string GetRuleName() const override { return "ReductionFactoring"; }

  // Applies the rule on a copy of `state`; returns the resulting state(s).
  std::vector<SearchState> ApplyOnBlock(SearchState state,
                                        const std::string& block_name) override;

  // Performs the actual transformation on `ir_schedule` in place.
  void Apply(const std::string& block_name, ir::IRSchedule* ir_schedule);

 private:
  // Preconditions: block writes a buffer, has at least one reduce axis,
  // and every enclosing loop body is simple (modulo a reduce_init block).
  bool CanApply(const std::string& block_name,
                ir::IRSchedule* ir_schedule) const;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/concrete_program_builder.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace auto_schedule {
// Test fixture for ReductionFactoring on NVGPU-lowered reduce programs.
class TestReductionFactoring : public TestAutoGenRuleBase {
 public:
  std::vector<std::string> default_input_names = {"X"};
  std::vector<std::string> default_output_names = {"out"};

  // Builds a reduce program over `shape` reduced along `reduce_dim`,
  // applies ReductionFactoring on `block_name`, and compares the printed
  // IR of the result against `expected_ir` character-for-character.
  void TestApplyOnReduce(const std::vector<int>& shape,
                         const std::vector<int>& reduce_dim,
                         const std::string& block_name,
                         const std::string& expected_ir) {
    Initialize(common::DefaultNVGPUTarget());
    // In order to forcibly use the most basic Compute of reduction
    FLAGS_cinn_new_group_scheduler = 1;
    auto test_program = tests::ReduceBuilder().Build(
        {{"X", shape}}, {{"reduce_dim", reduce_dim}});
    // construct input parameter
    ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
    SearchState state(ir_schedule, 0, {});
    std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
    ASSERT_EQ(func_bodys.size(), 1UL);
    VLOG(6) << "Original Expr:\n" << func_bodys[0];

    // apply
    ReductionFactoring reduction_factoring(target_);
    ASSERT_EQ(reduction_factoring.AnalyseApplyType(state, block_name),
              RuleApplyType::kApply);
    auto result = reduction_factoring.ApplyOnBlock(state, block_name)[0];
    std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
    EXPECT_EQ(exprs.size(), 1UL);
    std::stringstream ir;
    ir << exprs[0];
    VLOG(6) << "ReductionFactoring applied Expr: " << exprs[0];

    // check
    // NOTE(review): this inspects `ir_schedule` (the pre-apply schedule),
    // not `result->ir_schedule` — confirm the block count of the original
    // schedule is what this test intends to assert.
    const std::vector<ir::Expr>& blocks = ir_schedule.GetAllBlocks();
    CHECK_EQ(blocks.size(), 2UL);
    CHECK_EQ(ir.str(), expected_ir);
  }
};
// An elementwise_add has no reduce axis, so AnalyseApplyType must report
// kCannotApply for its output block ("var_1").
TEST_F(TestReductionFactoring, AnalyseApplyType) {
  Context::Global().ResetNameId();
  Initialize(common::DefaultNVGPUTarget());
  auto test_program =
      tests::OpBuilder("elementwise_add").Build({{"X", {4, 5}}, {"Y", {4, 5}}});
  ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
  VLOG(6) << "Original Expr:\n" << ir_schedule.GetModule().GetExprs()[0];
  SearchState state(ir_schedule, 0, {});
  ReductionFactoring reduction_factoring(target_);
  EXPECT_EQ(reduction_factoring.AnalyseApplyType(state, "var_1"),
            RuleApplyType::kCannotApply);
}
#ifdef CINN_WITH_CUDA
// Reduce {32, 64} over dim 1: expect a 64-wide threadIdx.x-bound rf
// reduction (extent 64 divides the 1024 factor limit, so the inner split
// loop degenerates to extent 1).
TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) {
  Context::Global().ResetNameId();
  // Expected IR text; compared verbatim, do not reformat.
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f
}
{
serial for (reduce_k_0_1, 0, 1)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1)
var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64}, {1}, "var_0", expected_ir);
}
// Reduce {32, 64, 128} over dims {1, 2}: both reduction loops are fused
// (extent 8192), then split as 1024 x 8 with the 1024 part thread-bound.
TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) {
  Context::Global().ResetNameId();
  // Expected IR text; compared verbatim, do not reformat.
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64, 128}, {1, 2}, "var_0", expected_ir);
}
// Reduce {32, 64, 64, 64} over dims {1, 2, 3}: three reduction loops are
// fused (extent 262144), then split as 1024 x 256 with the 1024 part
// thread-bound.
TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) {
  Context::Global().ResetNameId();
  // Expected IR text; compared verbatim, do not reformat.
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir);
}
#endif
} // namespace auto_schedule
} // namespace cinn
......@@ -21,6 +21,7 @@
#include <iostream>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
......@@ -52,9 +53,9 @@ TEST(SkipRule, Basic) {
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
......@@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ast_gen_ius::TensorGroup tensor_group({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
......
......@@ -61,12 +61,14 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target_);
lowered_funcs_ =
op_lowerer.Lower(graph->fusion_groups.front(),
/*apply_op_schedule = */ apply_manual_schedule,
/*apply_group_schedule = */ apply_manual_schedule);
/*apply_group_schedule = */ apply_manual_schedule,
/*apply_pass = */ apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
std::vector<Expr> bodys;
......
......@@ -34,7 +34,7 @@
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(auto_schedule_use_cost_model);
PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
......
......@@ -20,9 +20,9 @@
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h"
......@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
// compare exprs size firstly
if (lhs_exprs.size() != rhs_exprs.size()) return false;
// compare every expr one by one with ir::IrEqualVisitor
// compare every expr one by one with ir::ir_utils::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
if (!ir::ir_utils::IRCompare(lhs_exprs[i], rhs_exprs[i], true))
return false;
}
return true;
}
......
......@@ -20,9 +20,9 @@
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace auto_schedule {
......@@ -70,8 +70,8 @@ struct SearchStateHash {
size_t operator()(const SearchState& s) const;
};
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
// SearchStateHash equal functor, use ir::ir_utils::IrEqualVisitor to compare
// their AST struct and fields
struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
......@@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) {
ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
ast_gen_ius::TensorGroup const_group_1({A, B});
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
auto a_plus_const_funcs_1 =
lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_1, target);
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
ast_gen_ius::TensorGroup const_group_2({A, B});
auto a_plus_const_funcs_2 =
lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_2, target);
cinn::common::Context::Global().ResetNameId();
auto a_plus_b_funcs = lang::LowerVec("A_plus_B",
poly::CreateStages({A, C}),
{A, C},
{},
{},
nullptr,
target,
true);
ast_gen_ius::TensorGroup plus_group({A, C});
auto a_plus_b_funcs =
lang::LowerToAstVec("A_plus_B", {A, C}, &plus_group, target);
std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
{
......
......@@ -4,5 +4,6 @@ core_gather_headers()
gather_srcs(cinnapi_src SRCS evolutionary_search.cc)
cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
cinncore test_program_builder)
# TODO(zhhsplendid): enable this test again
#cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
# cinncore test_program_builder)
......@@ -36,7 +36,7 @@
#include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool(auto_schedule_use_cost_model);
PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
......@@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) {
ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
......@@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
cross_over_exprs.push_back(optim::IRCopy(father_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(father_exprs[i]));
} else {
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(mother_exprs[i]));
}
}
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
......@@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate(
// ir_schedule
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::IRSchedule pir_sch(
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch));
new_trace.Replay(&pir_sch, true);
ApplyPostScheduleRules(&pir_sch, post_schedule_rules_);
auto res = SearchState(std::move(pir_sch));
VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment