Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
AutoInline::AutoInline(
const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names)
: AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {}
bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::Expr compute_body = sche_block->body;
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
return false;
}
ir::Expr tensor_expr = (*find_store.begin()).As<ir::Store>()->tensor;
ir::Tensor tensor = tensor_expr.as_tensor_ref();
if (tensor->is_reduce_tensor()) {
return false;
}
// LoweredFunc output can be tensor name or tensor buffer name
if (no_inline_output_names_.find(tensor->name) !=
no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) !=
no_inline_output_names_.end()) {
return false;
}
// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
if (sche_block->read_buffers.empty()) {
return true;
}
// Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
});
if (find_store.size() != 1UL) {
return false;
}
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) {
return false;
}
ir::Expr store = *(find_store.begin());
ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(),
store);
if (!inliner.BodyPatternAllowInline()) {
return false;
}
ir::LeafBlockRemovalPlan remove_plan(
sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
return false;
}
VLOG(6) << "Found store Expr " << store << ", which CanInlineIntoConsumer";
return true;
}
AutoInlineType AutoInline::AnalyzeInlineType(
const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
// Inline if the block has only 1 write buffer
if (sche_block->write_buffers.size() != 1) {
return AutoInlineType::kCannotInline;
}
std::unordered_set<ir::IrNodeTy> no_inline_node_types = {
ir::IrNodeTy::IfThenElse};
if (ContainsNodeType(sche_block->body, no_inline_node_types)) {
return AutoInlineType::kCannotInline;
}
// InlineIntoConsumer other than above situations
if (CanInlineIntoConsumer(sche_block_realize_expr, ir_sch)) {
return AutoInlineType::kInlineIntoConsumer;
}
// TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
// so we just do kInlineIntoConsumer here. Add CanInlineIntoProducer
// once ReverseComputeInline is ready.
return AutoInlineType::kCannotInline;
}
RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule;
all_block_realizes_ = ir_schedule_->GetAllBlocks();
apply_indices_and_type_.clear();
num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type =
AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
if (type != AutoInlineType::kCannotInline) {
++num_applicable_;
apply_indices_and_type_.push_back({i, type});
}
}
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void AutoInline::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
CHECK(num_applicable_ > 0 &&
apply_indices_and_type_.size() == num_applicable_)
<< "AutoInline::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && "
"index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = apply_indices_and_type_[index].first;
Apply(ir_schedule_, all_block_realizes_[apply_index]);
return;
}
std::string AutoInline::GetRuleName() const { return "AutoInline"; }
RuleApplyType AutoInline::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule);
return type == AutoInlineType::kCannotInline
? RuleApplyType::kCannotApply
: RuleApplyType::kApplyAndPruneOtherRules;
}
std::vector<SearchState> AutoInline::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Apply(&new_state->ir_schedule, block_expr);
return {new_state};
}
void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule);
if (type == AutoInlineType::kInlineIntoConsumer) {
VLOG(6) << "Apply ComputeInline on " << block_expr;
ir_schedule->ComputeInline(block_expr);
VLOG(6) << "After ComputeInline: " << block_expr;
} else if (type == AutoInlineType::kInlineIntoProducer) {
// TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
// so we just do kInlineIntoConsumer here. Add CanInlineIntoConsumer
// once ReverseComputeInline is ready.
// ir_schedule->ReverseComputeInline(all_block_realizes_[apply_index]);
}
// Make sure re-apply the AutoInline won't be error.
// AutoInline changes the read and write buffers of schedule blocks,
// we need to re-analyze
all_block_realizes_ = ir_schedule->GetAllBlocks();
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
sche_block->read_buffers = {};
sche_block->write_buffers = {};
AnalyzeScheduleBlockReadWriteBuffer(sche_block);
}
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
* The types of the AutoInline
*/
enum class AutoInlineType : int {
// The block cannot be inlined
kCannotInline = 0,
// Inline this block into the consumer
kInlineIntoConsumer,
// Inline this block into the producer
kInlineIntoProducer,
};
class AutoInline : public AutoGenRule {
public:
AutoInline(const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names);
~AutoInline() = default;
RuleApplyType Init(ir::IRSchedule* ir_schedule) override;
void Apply(int index) override;
std::string GetRuleName() const override;
AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT
private:
std::vector<ir::Expr> all_block_realizes_;
std::vector<std::pair<int, AutoInlineType>> apply_indices_and_type_;
std::unordered_set<std::string> no_inline_output_names_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/concrete_program_builder.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::OpLowerer;
TEST(AutoInline, SingleLoopInline) {
srand(0);
Context::Global().ResetNameId();
Target target = common::DefaultHostTarget();
Expr M(32);
Placeholder<float> A("A", {M});
ir::Tensor B = Compute(
{M}, [&](Var i) { return A(i) * ir::Expr(2.f); }, "B");
ir::Tensor C = Compute(
{M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C");
poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline",
stages,
{A, C},
{},
{},
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body;
/*
* We have to use ComputeAt to put two Tensor loops together to create IR
* test case for AutoInline.
*/
ir::IRSchedule ir_sch(ir::ModuleExpr(std::vector<ir::Expr>{funcs[0]->body}));
SearchState state(ir_sch, 0, {});
ir::Expr block_b = ir_sch.GetBlock("B");
std::vector<ir::Expr> loops = ir_sch.GetLoops("C");
ir_sch.ComputeAt(block_b, loops[0]);
ir::ModuleExpr mod_expr_before_inline = ir_sch.GetModule();
VLOG(6) << "Expr after ComputeAt:";
VLOG(6) << mod_expr_before_inline.GetExprs()[0];
AutoInline auto_inline(target, {"C"});
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.NumberApplicable(), 1);
auto_inline.ApplyRandomly();
std::vector<ir::Expr> exprs = ir_sch.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
// ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "B");
auto test_func = [](ir::IRSchedule* ir_sch) {
ir::ModuleExpr mod_expr_after_inline = ir_sch->GetModule();
std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << "After AutoInline:";
VLOG(6) << expr_str;
std::string target_str = R"ROC({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(C)
{
i0 = axis.bind(i)
read_buffers(_A[i0(0:32)])
write_buffers(_C[i0(0:32)])
C[i0] = ((A[i0] * 2.00000000f) + 1.00000000f)
}
}
}
}
})ROC";
EXPECT_EQ(expr_str, target_str);
};
test_func(&ir_sch);
test_func(&new_states[0]->ir_schedule);
// Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"),
RuleApplyType::kCannotApply);
}
TEST(AutoInline, AddReluInline) {
srand(0);
Context::Global().ResetNameId();
Target target = common::DefaultHostTarget();
frontend::NetBuilder builder("test");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
frontend::Program program = builder.Build();
auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs =
op_lowerer->Lower(graph->fusion_groups[0],
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
ir::ModuleExpr mod_expr_before_inline(std::vector<Expr>({funcs[0]->body}));
ir::IRSchedule ir_sch(mod_expr_before_inline);
SearchState state(ir_sch, 0, {});
AutoInline auto_inline(target, {"var_2"});
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.NumberApplicable(), 2);
auto_inline.Apply(1);
ir::ModuleExpr mod_expr_after_inline = ir_sch.GetModule();
std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << "After AutoInline:";
VLOG(6) << expr_str;
// Auto Inline again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.NumberApplicable(), 1);
auto_inline.Apply(0);
// ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
// Auto Inline again
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3");
auto test_func = [](ir::IRSchedule* ir_sch) {
ir::ModuleExpr final_mod_expr = ir_sch->GetModule();
auto exprs = final_mod_expr.GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << "Final AutoInline:";
VLOG(6) << expr_str;
std::string target_str = R"ROC({
ScheduleBlock(root)
{
{
serial for (i, 0, 1)
{
serial for (j, 0, 64)
{
serial for (k, 0, 112)
{
serial for (a, 0, 112)
{
ScheduleBlock(var_2)
{
i0, i1, i2, i3 = axis.bind(0, j, k, a)
read_buffers(_A[i0(0:1), i1(0:64), i2(0:112), i3(0:112)], _B[i1(0:64)])
write_buffers(_var_2[i0(0:1), i1(0:64), i2(0:112), i3(0:112)])
var_2[i0, i1, i2, i3] = cinn_max((A[i0, i1, i2, i3] + B[i1]), 0.00000000f)
}
}
}
}
}
}
}
})ROC";
EXPECT_EQ(expr_str, target_str);
};
test_func(&ir_sch);
test_func(&new_states[0]->ir_schedule);
// Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"),
RuleApplyType::kCannotApply);
}
#ifdef CINN_WITH_CUDA
class TestAutoInline : public TestAutoGenRuleBase {};
/* The single chain graph composed of multiple blocks can be inlined into one.
*
* Before AutoInline: The output of the previous block is the input of another
* block. Loop1: x1 = Add() Loop2: x2 = Multiply(x1) Loop3: x3 = Add(x2) Loop4:
* x4 = Relu(x3)
*
* After AutoInline: All loops are inlined into a loop.
* Loop:
* Add(Multiply(Add(Relu())))
*/
TEST_F(TestAutoInline, SingleChain) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {
"bias", "conv_output", "bn_scale", "bn_offset"};
std::vector<std::string> output_names = {
"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"};
std::vector<int32_t> conv_output_shape = {1, 512, 56, 56};
int32_t channel = conv_output_shape[1];
std::vector<tests::VariableInfo> inputs_varinfo(
{{"conv_output", conv_output_shape},
{"bias", {channel, 1, 1}},
{"bn_scale", {channel, 1, 1}},
{"bn_offset", {channel, 1, 1}}});
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_3");
std::vector<std::string> inline_block_names(
{"var_4", "var_5", "var_6", "var", "var_0", "var_1"});
for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
}
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
VLOG(6) << " manually-schedule source code:\n" << source_code_manually;
CheckResult(GenExecutableKernel(build_module_auto),
GenExecutableKernel(build_module_manually),
input_names,
output_names,
{{conv_output_shape[1], 1, 1},
conv_output_shape,
conv_output_shape,
conv_output_shape},
{conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}},
target);
}
/* An op can be inlined into multiple consumers at the same time.
*
* Before AutoInline: The output of Exp is used by Add and Multiply.
* Loop1:
* x = Exp()
* Loop2:
* y = Add(x)
* Loop3:
* z = Multiply(x)
*
* After AutoInline: Exp is inlined into Add and Multiply.
* Loop:
* y = Add(Exp())
* z = Multiply(Exp())
*/
TEST_F(TestAutoInline, InlineToMultiConsumers) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"x"};
std::vector<std::string> output_names = {"var_2", "var_1", "var_0"};
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
new_states = auto_inline.ApplyOnBlock(state, "var_0");
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
VLOG(6) << " manually-schedule source code:\n" << source_code_manually;
CheckResult(GenExecutableKernel(build_module_auto),
GenExecutableKernel(build_module_manually),
input_names,
output_names,
{input_shape},
{input_shape, {1}, {1}},
target);
}
/* Operators of type elementwise or injective can all be inlined.
*
* Before AutoInline: A graph of Gather, Add and Subtract
* Loop1:
* x1 = Gather()
* Loop2:
* x2 = Add(x1)
* Loop3:
* y1 = Gather()
* Loop4:
* z1 = Subtract(y1, x1)
*
* After AutoInline: All loops are inlined to one
* z1 = Subtract(Gather(), Add(Gather()))
*/
TEST_F(TestAutoInline, OnlySpatialOp) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = {"var_6",
"var_4",
"constant_idx_last",
"constant_idx_first",
"var_2",
"var_5"};
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo(
{{"x", input_shape}, {"y", input_shape}});
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first");
std::vector<std::string> inline_block_names(
{"constant_idx_last", "var_2", "var_5", "var_4"});
for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
}
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
VLOG(6) << " manually-schedule source code:\n" << source_code_manually;
CheckResult(GenExecutableKernel(build_module_auto),
GenExecutableKernel(build_module_manually),
input_names,
output_names,
{input_shape, input_shape},
{input_shape, {1}, {1}, {1}, {1}, {1}},
target);
}
/* An op that does not read data can be directly inlined.
*
* Before AutoInline: fill_constant op is in a separate loop.
* Loop1:
* x = fill_constant()
* Loop2:
* y = Add(x)
*
* After AutoInline: fill_constant op is inlined into other loop
* Loop:
* y = Add(fill_constant())
*/
TEST_F(TestAutoInline, NoReadBufferOp) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"x"};
std::vector<std::string> output_names = {"var_0", "fill_constant"};
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant");
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
VLOG(6) << " manually-schedule source code:\n" << source_code_manually;
CheckResult(GenExecutableKernel(build_module_auto),
GenExecutableKernel(build_module_manually),
input_names,
output_names,
{input_shape},
{input_shape, {1}},
target);
}
/* An op can be inlined into multiple producers at the same time.
*/
// TEST_F(TestAutoInline, InlineToMultiProducers) {
// TODO(6clc): Complete the unit test, once ReverseComputeInline is ready.
// }
#endif
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include <glog/logging.h>
#include <cstdlib>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
static std::vector<int> auto_unroll_options = {0, 8, 32, 128};
bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
// whether any block has reduce iter
auto has_reduce_iter = [](const Expr* x) {
auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (block_realize) {
auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
for (auto&& var : schedule_block->iter_vars) {
if (var->is_reduce_axis) {
VLOG(6) << "find ScheduleBlockRealize:" << *x
<< " has reduce_axis:" << var;
return true;
}
}
}
return false;
};
// whether has any for-loop with non-serial type
auto has_nonserial_loop = [](const Expr* x) {
if (x->As<ir::For>() &&
x->As<ir::For>()->for_type() != ir::ForType::Serial) {
VLOG(6) << "find non-serial loop:" << *x;
return true;
}
return false;
};
auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
});
return !find_target_exprs.empty();
}
RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule;
auto block_realizes = ir_schedule_->GetAllBlocks();
// A schedule block can perform `auto_unroll` rule should meet two conditions:
// (1) it is a root block
// (2) MeetCondition returns true with it
applicable_schedule_blocks_.clear();
std::set<Expr> deduplicate_results;
for (size_t i = 0; i < block_realizes.size(); ++i) {
// find root block
Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
if (MeetCondition(schedule_block)) {
deduplicate_results.emplace(root_block);
}
}
applicable_schedule_blocks_ = {deduplicate_results.begin(),
deduplicate_results.end()};
num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void AutoUnroll::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
ir_schedule_->Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return;
}
RuleApplyType AutoUnroll::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
Expr root_block = state->ir_schedule.GetRootBlock(block_expr);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
std::vector<SearchState> AutoUnroll::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
new_state->ir_schedule.Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return {new_state};
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with
// non-serial type. As a result, it will set a attribute with key named
// ir::attr::auto_unroll_max_step and value indicating max permitted unrolled
// step in the applied ScheduleBlock. Finally, UnrollLoop pass will do unroll
// based on actual situation.
class AutoUnroll : public AutoGenRule {
public:
explicit AutoUnroll(const common::Target& target) : AutoGenRule(target) {}
~AutoUnroll() = default;
RuleApplyType Init(ir::IRSchedule* init_schedule) override;
void Apply(int index) override;
std::string GetRuleName() const override { return "AutoUnroll"; }
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
bool MeetCondition(const ir::ScheduleBlock* schedule_block) const;
private:
std::vector<Expr> applicable_schedule_blocks_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/lower.h"
namespace cinn {
namespace auto_schedule {
TEST(AutoUnroll, Init) {
using namespace ir; // NOLINT
Expr M(100);
Expr N(4);
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});
Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec(
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
AutoUnroll test_rule(target);
// not meet specific condition
ASSERT_EQ(test_rule.Init(&init_schedule), RuleApplyType::kCannotApply);
}
TEST(AutoUnroll, UnrollableApply) {
using namespace ir; // NOLINT
Expr M(100);
Expr N(4);
Expr K(32);
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "k0");
Tensor C = Compute(
{M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec(
"test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body;
auto* init_block_realize =
ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* init_schedule_block =
init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(6) << "Before auto-unroll:\n" << ast_expr;
AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
ASSERT_EQ(test_rule.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(test_rule.NumberApplicable(), 1);
test_rule.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
std::vector<cinn::auto_schedule::SearchState> states =
test_rule.ApplyOnBlock(state, "C");
auto test_func = [](IRSchedule* ir_sch) {
Expr applied_expr = ir_sch->GetModule().GetExprs().front();
auto* applied_block_realize = applied_expr.As<ir::Block>()
->stmts.front()
.As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block =
applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_FALSE(applied_schedule_block->attrs.empty());
EXPECT_EQ(
applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const auto& attr_value =
applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128);
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n"
<< ir_sch->GetModule().GetExprs().front();
};
test_func(&ir_schedule);
test_func(&states[0]->ir_schedule);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
namespace auto_schedule {
class TestMixRules : public TestAutoGenRuleBase {
public:
std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"};
};
TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}});
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
multi_level_tiling.Init(&ir_schedule);
ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
VLOG(6) << "after MultiLevelTiling Expr:\n" << func_bodys[0];
// build ir::Module and debug source code
auto ir_module = BuildIRModule(ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(
MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
default_input_names,
default_output_names,
{{32, 32}, {32, 32}},
{{32, 32}},
target_);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
MultiLevelTiling::MultiLevelTiling(const common::Target& target,
const Config& config)
: AutoGenRule(target), config_(config) {
for (int i = 0; i < config_.tile_struct.size(); ++i) {
if (config_.tile_struct[i] == 'S') {
s_indices_.push_back(i);
} else if (config_.tile_struct[i] == 'R') {
r_indices_.push_back(i);
} else {
CHECK(false) << "Illegal tiling structure string";
}
}
}
bool MultiLevelTiling::MeetCondition(
const ir::ScheduleBlockRealize& sche_block_realize) const {
return NeedsMultiLevelTiling(sche_block_realize);
}
RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule;
all_block_realizes_ = ir_schedule_->GetAllBlocks();
applicable_indices_.clear();
num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
if (MeetCondition(*sche_block_realize)) {
++num_applicable_;
applicable_indices_.push_back(i);
}
}
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void MultiLevelTiling::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run MultiLevelTiling::Apply without Init";
CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_)
<< "MultiLevelTiling::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= "
"index && index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = applicable_indices_[index];
std::string block_name = all_block_realizes_[apply_index]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
Expr block_expr = all_block_realizes_[apply_index];
ApplyTiling(ir_schedule_, block_expr);
block_expr = ir_schedule_->GetBlock(block_name);
ApplyCacheRead(ir_schedule_, block_expr);
block_expr = ir_schedule_->GetBlock(block_name);
ApplyCacheWrite(ir_schedule_, block_expr);
VLOG(4) << "Returning the result of MultiLevelTiling";
return;
}
std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; }
RuleApplyType MultiLevelTiling::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
return NeedsMultiLevelTiling(*block_realize)
? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
ir::IRSchedule* ir_sch = &new_state->ir_schedule;
Expr block_expr = ir_sch->GetBlock(block_name);
ApplyTiling(ir_sch, block_expr);
block_expr = ir_sch->GetBlock(block_name);
ApplyCacheRead(ir_sch, block_expr);
block_expr = ir_sch->GetBlock(block_name);
ApplyCacheWrite(ir_sch, block_expr);
VLOG(4) << "Returning the result of MultiLevelTiling";
return {new_state};
}
void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sche_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
tile_loops_.clear();
tile_loops_.resize(config_.tile_struct.size());
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
VLOG(5) << "The number of loops to split in MultiLevelTiling is "
<< for_exprs.size();
for (int i = for_exprs.size() - 1; i >= 0; --i) {
ir::For* ir_for = for_exprs[i].As<ir::For>();
VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for);
const std::vector<int>* idx = nullptr;
if (sche_block->iter_vars[i]->is_reduce_axis) {
idx = &r_indices_;
} else {
idx = &s_indices_;
} // TODO(zhhsplendid): support more iterator variable types
int extent = ir_for->extent.as_int32(); // maybe int64?
int num_split = idx->size();
if (num_split > 1) {
std::vector<Expr> tile_split_factor =
ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64);
std::vector<Expr> splited =
ir_schedule->Split(Expr(ir_for), tile_split_factor);
VLOG(6) << "Finish Split for MultiLevelTiling on above loop";
for (int j = 0; j < num_split; ++j) {
tile_loops_[idx->at(j)].push_back(splited[j]);
}
} else {
tile_loops_[idx->at(0)].push_back(for_exprs[i]);
}
}
VLOG(5) << "Finish Split in MultiLevelTiling, before Reorder.";
// Have to GetLoops again because Split can change Block Expr(s)
for_exprs = ir_schedule->GetLoops(sche_block->name);
std::unordered_map<std::string, int> loop_var_name_to_idx;
for (int i = 0; i < for_exprs.size(); ++i) {
loop_var_name_to_idx[for_exprs[i].As<ir::For>()->loop_var->name] = i;
}
CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names after split";
std::vector<Expr> splited_loops;
for (auto& t : tile_loops_) {
std::reverse(t.begin(), t.end());
for (auto& tile_loop_expr : t) {
const ir::For* tile_loop = tile_loop_expr.As<ir::For>();
CHECK(tile_loop) << "tiles store non For Expr";
int idx = loop_var_name_to_idx[tile_loop->loop_var->name];
splited_loops.push_back(for_exprs[idx]);
}
}
Expr reordered_expr = ir_schedule->Reorder(splited_loops);
VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on "
"the main loop chain";
int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size());
for (int i = 0; i < num_binds; ++i) {
loop_var_name_to_idx.clear();
for_exprs = ir_schedule->GetLoops(sche_block->name);
for (int j = 0; j < for_exprs.size(); ++j) {
loop_var_name_to_idx[for_exprs[j].As<ir::For>()->loop_var->name] = j;
}
CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names before Fusion";
// Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether
// the fused loop extent, which is the production of extends of loops
// to be fused, is less or equal to the max factor.
//
// If yes, we fuse those loops and bind the fused loop
// If no, we bind the first loop whose extent is less than the factor.
int extent_prod = 1;
int first_idx_less_than_max_factor = -1;
for (int j = 0; j < tile_loops_[i].size(); ++j) {
const ir::For* tile_loop = tile_loops_[i][j].As<ir::For>();
CHECK(tile_loop) << "tiles store non For Expr";
int idx = loop_var_name_to_idx[tile_loop->loop_var->name];
tile_loops_[i][j] = for_exprs[idx];
int extent = tile_loop->extent.as_int32(); // maybe int64?
extent_prod *= extent;
if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
first_idx_less_than_max_factor = idx;
}
}
if (extent_prod <= max_factor_) {
Expr fused = ir_schedule->Fuse(tile_loops_[i]);
ir_schedule->Bind(fused, config_.bind_axis[i]);
} else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
}
}
VLOG(5) << "Do Fuse and Binding on the non-main loop chains";
Expr sche_block_top_loop = ir_schedule->GetLoops(sche_block->name)[0];
if (reordered_expr.As<ir::Block>()) {
for (Expr& top_loop : reordered_expr.As<ir::Block>()->stmts) {
if (top_loop != sche_block_top_loop) {
std::vector<Expr> scan_loop_blocks = ir_schedule->GetAllBlocks();
Expr other_loop_chain_schedule;
for (Expr& block : scan_loop_blocks) {
std::vector<Expr> loop_chain = ir_schedule->GetLoops(block);
if (loop_chain[0] == top_loop) {
other_loop_chain_schedule = block;
break;
}
}
if (!other_loop_chain_schedule.defined()) {
LOG(WARNING) << "Has non-main loop chain, but not corresponding "
"ScheduleBlock in MultiLevelTiling";
continue;
}
std::string other_loop_schedule_name =
other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
VLOG(6) << "Found other_loop_schedule_name = "
<< other_loop_schedule_name;
int fuse_index = 0;
for (int i = 0; i < num_binds; ++i) {
for_exprs = ir_schedule->GetLoops(other_loop_schedule_name);
// Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether
// the fused loop extent, which is the production of extends of loops
// to be fused, is less or equal to the max factor.
//
// If yes, we fuse those loops and bind the fused loop
// If no, we bind the first loop whose extent is less than the factor.
int extent_prod = 1;
int first_idx_less_than_max_factor = -1;
for (int j = 0; j < tile_loops_[i].size(); ++j) {
int extent =
for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32();
extent_prod *= extent;
if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
first_idx_less_than_max_factor = fuse_index + j;
}
}
if (extent_prod <= max_factor_) {
std::vector<Expr> loops_to_fuse(
for_exprs.begin() + fuse_index,
for_exprs.begin() + fuse_index + tile_loops_[i].size());
Expr fused = ir_schedule->Fuse(loops_to_fuse);
ir_schedule->Bind(fused, config_.bind_axis[i]);
fuse_index += 1;
} else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
fuse_index += tile_loops_[i].size();
}
}
}
}
}
}
void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sch_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sch_block =
sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
std::string block_name = sch_block->name;
// Analyze which buffers can be cached
std::vector<int> read_buffer_indexes;
for (int i = 0; i < sch_block->read_buffers.size(); ++i) {
bool is_read_write = false;
for (int j = 0; j < sch_block->write_buffers.size(); ++j) {
if (sch_block->read_buffers[i] == sch_block->write_buffers[j]) {
is_read_write = true;
break;
}
}
if (!is_read_write) {
read_buffer_indexes.push_back(i);
}
}
// Schedule
for (int read_buffer_index : read_buffer_indexes) {
for (int level : config_.read_cache_levels) {
// 1.find target loop
const auto loops = tile_loops_.at(level - 1);
if (loops.size() == 0) {
continue;
}
// 2.Do CacheRead and get the cache block
ir::Expr cache_block = ir_schedule->CacheRead(
block_expr, read_buffer_index, config_.read_cache_memory_type);
std::string cache_block_name =
cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
std::string target_for_loop_name =
loops.back().As<ir::For>()->loop_var->name;
// 3.Place the cache_block under target_for_loop
// The original block expr is invalid after the CacheRead schedule,
// so we reacquire the block expr after the schedule according to the
// block name
block_expr = ir_schedule->GetBlock(block_name);
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos) {
ir_schedule->ComputeAt(cache_block, for_expr, true);
break;
}
}
// 4.Threads under the same block cooperative fetch data from global
// memory.
Expr new_cache_block = ir_schedule->GetBlock(cache_block_name);
auto cache_block_loops = ir_schedule->GetLoops(new_cache_block);
std::vector<std::string> compute_at_extra_var = utils::Split(
absl::get<std::string>(new_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->attrs.at("compute_at_extra_var")),
",");
std::vector<Expr> buffer_loops;
// int nthreads = 1;
for (const Expr& for_expr : cache_block_loops) {
if (std::find(compute_at_extra_var.begin(),
compute_at_extra_var.end(),
for_expr.As<ir::For>()->loop_var->name) !=
compute_at_extra_var.end()) {
buffer_loops.push_back(for_expr);
}
}
auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops);
// TODO(BiynXu): Implement vectorize fetching data and pass in vector
// length
ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name),
ir::attr::cooperative_process,
0);
}
}
}
void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::Expr cache_block =
ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type);
for (int level : config_.write_cache_levels) {
const auto loops = tile_loops_.at(level - 1);
if (loops.size() == 0) {
continue;
}
std::string target_for_loop_name =
loops.back().As<ir::For>()->loop_var->name;
// Because the block name is changed in CacheWrite, we need to calculate the
// derived name according to the logic of CacheWrite and find the loop
// structure according to the derived name.
const std::string original_block_name =
block_expr.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
const std::string derivative_block_name = original_block_name + "_" +
config_.write_cache_memory_type +
"_temp_buffer";
std::vector<Expr> for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos) {
ir_schedule->ReverseComputeAt(
ir_schedule->GetBlock(original_block_name), for_expr, true);
}
}
const std::string reduce_init_block_name =
original_block_name + "__reduce_init";
for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos &&
ir_schedule->HasBlock(reduce_init_block_name)) {
ir_schedule->SimpleComputeAt(
ir_schedule->GetBlock(reduce_init_block_name), for_expr);
}
}
}
}
const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config>
MultiLevelTiling::kConfigs{
{common::Target::Arch::NVGPU,
MultiLevelTiling::Config{
/*bind_axis*/ std::vector<std::string>{"blockIdx.x",
"threadIdx.x"},
/*tile_struct*/ std::string("SSSRRSRS"),
/*read_cache_memory_type*/ std::string("shared"),
/*read_cache_levels*/ std::vector<int>{4},
/*write_cache_memory_type*/ std::string("local"),
/*write_cache_levels*/ std::vector<int>{3},
}},
{common::Target::Arch::X86,
MultiLevelTiling::Config{
/*bind_axis*/ std::vector<std::string>{},
/*tile_struct*/ std::string("SSRSRS"),
/*read_cache_memory_type*/ std::string("local"),
/*read_cache_levels*/ std::vector<int>{3},
/*write_cache_memory_type*/ std::string("local"),
/*write_cache_levels*/ std::vector<int>{2},
}}};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
class MultiLevelTiling : public AutoGenRule {
public:
struct Config {
// Which thread axis each tiled loop is bound to
std::vector<std::string> bind_axis;
// Use char 'S' and 'R' to represent tile structure.
// S means space tiling level and R means reduce tiling level
//
// For example, if tile_struct_ = "SSRSRS" and we are doing matrix
// multiplication, i, j are the spatial indices and k is the reduce index,
// the tiling result will be i_0, j0, i1, j1, k0, i2, j2, k1, i3, j3
std::string tile_struct;
// The storage type of read cache
std::string read_cache_memory_type;
// Which tiled levels are read cache block inserted at
std::vector<int> read_cache_levels;
// The storage type of write cache
std::string write_cache_memory_type;
// Which tiled levels are write cache block inserted at
std::vector<int> write_cache_levels;
};
static const std::unordered_map<common::Target::Arch, Config> kConfigs;
MultiLevelTiling(const common::Target& target, const Config& config);
~MultiLevelTiling() = default;
// initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise
RuleApplyType Init(ir::IRSchedule* init_schedule) override;
// Applies rule on the ir::ModuleExpr for a schedule block specified by index
// between 0 (inclusive) and NumberApplicable() (exclusive)
void Apply(int index) override;
// Returns the name of the rule, used for debug.
std::string GetRuleName() const override;
// Returns true if sche_block_realize is applicable by MultiLevelTiling
bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
// Sample pair of integer type (a, b) such as a * b = extent
template <typename T>
std::vector<T> SampleSplitTwo(T extent) const {
std::vector<std::vector<T>> candidates;
for (T div = 1; div <= sqrt(extent); ++div) {
if (extent % div == 0) {
candidates.push_back({T(div), extent / div});
}
}
if (candidates.size() == 0) {
return {1, T(extent)};
}
int index = rand() % candidates.size(); // NOLINT
std::vector<T> pick = candidates[index];
if (rand() % 2 != 0) { // NOLINT
T tmp = pick[0];
pick[0] = pick[1];
pick[1] = tmp;
}
return pick;
}
// Sample num_split integers whose product equals extent
template <typename T>
std::vector<T> SampleTileSplit(T extent, int num_split) const {
CHECK_GT(num_split, 0)
<< "num_split in SampleTileSplit must be greater than 0";
if (num_split == 1) {
return {extent};
}
std::vector<T> two_split = SampleSplitTwo<T>(extent);
if (num_split == 2) {
return two_split;
}
int half = num_split >> 1;
std::vector<T> result = SampleTileSplit<T>(two_split[0], half);
std::vector<T> remind = SampleTileSplit<T>(two_split[1], num_split - half);
result.insert(result.end(), remind.begin(), remind.end());
return result;
}
private:
void ApplyTiling(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr); // NOLINT
void ApplyCacheRead(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr); // NOLINT
void ApplyCacheWrite(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr); // NOLINT
private:
std::vector<ir::Expr> all_block_realizes_;
std::vector<int> applicable_indices_;
Config config_;
std::vector<int> s_indices_;
std::vector<int> r_indices_;
std::vector<std::vector<ir::Expr>> tile_loops_;
// A factor to limit the split factor within max thread number per block
int max_factor_ = 1024;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
namespace auto_schedule {
TEST(MultiLevelTile, SampleSplitTwo) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) {
size_t number_to_split =
rand() % 65535 + 2; // NOLINT, random number in [2, 2^16]
std::vector<size_t> split =
multi_level_tiling.SampleSplitTwo<size_t>(number_to_split);
EXPECT_EQ(split.size(), 2UL);
EXPECT_EQ(split[0] * split[1], number_to_split);
}
}
TEST(MultiLevelTile, SampleTileSplit) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) {
int number_to_split =
rand() % 65535 + 2; // NOLINT, random number in [2, 2^16]
int split_size = rand() % 5 + 1; // NOLINT, random in [1, 5]
std::vector<int> split =
multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size);
EXPECT_EQ(split.size(), static_cast<size_t>(split_size));
int product = 1;
for (int num : split) {
product *= num;
}
EXPECT_EQ(product, number_to_split);
}
}
TEST(MultiLevelTile, SimpleLoops) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
Expr M(32);
Expr N(128);
Placeholder<float> A("A", {M});
Placeholder<float> B("B", {N});
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_SimpleLoops",
stages,
{C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << expr_str;
};
test_func(&ir_schedule);
test_func(&new_states[0]->ir_schedule);
}
// TODO(SunNy820828449): fix in future
/*
TEST(MulitLevelTile, MatrixMultiply) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
Expr M(32);
Expr N(32);
Expr K(32);
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target,
MultiLevelTiling::kConfigs.at(target.arch)); ir::IRSchedule
ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules); auto new_states =
multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << expr_str;
};
test_func(&ir_schedule);
test_func(&new_states[0]->ir_schedule);
}
*/
class TestMultiLevelTiling : public TestAutoGenRuleBase {
public:
int fixed_rand_seed = 1;
std::vector<std::string> default_input_names;
std::vector<std::string> default_output_names;
};
TEST_F(TestMultiLevelTiling, Matmul) {
default_input_names = {"X", "Y"};
default_output_names = {"temp_matmul_out"};
std::vector<int32_t> X_shape = {32, 32};
std::vector<int32_t> Y_shape = {32, 32};
std::vector<int32_t> out_shape = {32, 32};
Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
thread_bind[blockIdx.x] for (i_j_fused, 0, 4)
{
thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 1)
{
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 1)
{
serial for (i_2, 0, 1)
{
serial for (j_2, 0, 1)
{
serial for (i_3, 0, 8)
{
serial for (j_3, 0, 32)
{
ScheduleBlock(temp_matmul_out__reduce_init)
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f
}
}
}
}
}
}
{
serial for (reduce_k_0, 0, 4)
{
serial for (ax0_0_ax1_0_fused, 0, 256)
{
ScheduleBlock(Y_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_0_ax1_0_fused / 32) + (8 * reduce_k_0)), ((ax0_0_ax1_0_fused % 32) + (32 * j_1)))
attrs(compute_at_extra_var:ax0_0,ax1_0, cooperative_process:0)
{
Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1]
}
}
}
serial for (ax0_ax1_fused, 0, 64)
{
ScheduleBlock(X_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_ax1_fused / 8) + ((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused)))), ((ax0_ax1_fused % 8) + (8 * reduce_k_0)))
attrs(compute_at_extra_var:ax0,ax1, cooperative_process:0)
{
X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1]
}
}
}
serial for (reduce_k_1, 0, 1)
{
serial for (i_2, 0, 1)
{
serial for (j_2, 0, 1)
{
serial for (reduce_k_2, 0, 8)
{
serial for (i_3, 0, 8)
{
serial for (j_3, 0, 32)
{
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
}
}
}
}
}
}
}
}
}
serial for (ax0_1, 0, 8)
{
serial for (ax1_1, 0, 32)
{
ScheduleBlock(temp_matmul_out)
{
v0, v1 = axis.bind((((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused))) + ax0_1), ((32 * j_1) + ax1_1))
attrs(reverse_compute_at_extra_var:ax0_1,ax1_1)
{
temp_matmul_out[v0, v1] = temp_matmul_out_local_temp_buffer[v0, v1]
}
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code
auto ir_module = BuildIRModule(new_states[0]->ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{X_shape, Y_shape},
{out_shape},
target_);
}
TEST_F(TestMultiLevelTiling, ReduceSum) {
default_input_names = {"X"};
default_output_names = {"var_0_tmp"};
std::vector<int32_t> X_shape = {1, 16, 32};
std::vector<int32_t> out_shape = {1, 16, 1};
std::vector<int32_t> reduce_dim = {2};
Initialize(common::DefaultNVGPUTarget());
frontend::Program reduce_sum_op =
tests::OpBuilder("reduce_sum")
.Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}});
ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op);
SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state,
// default_output_names[0]), RuleApplyType::kCannotApply);
}
TEST_F(TestMultiLevelTiling, Pool2d) {
default_input_names = {"input"};
default_output_names = {"var_0", "pad_temp_0"};
std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
std::string pooling_type = "max";
std::vector<int> ksize{3, 3};
std::vector<int> strides{2, 2};
std::vector<int> paddings{1, 1, 1, 1};
bool ceil_mode = false;
bool exclusive = true;
bool global_pooling = false;
std::string data_format = "NCHW";
bool adaptive = false;
std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"input", input_shapes[0]}},
{{"pool_type", pooling_type},
{"kernel_size", ksize},
{"stride_size", strides},
{"padding_size", paddings},
{"ceil_mode", ceil_mode},
{"exclusive", exclusive},
{"global_pooling", global_pooling},
{"data_format", data_format},
{"adaptive", adaptive},
{"padding_algorithm", padding_algorithm}});
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(pool2d_program, fixed_rand_seed);
SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling
MultiLevelTiling::Config mlt_config = {
/*bind_axis*/ std::vector<std::string>{"blockIdx.x", "threadIdx.x"},
/*tile_struct*/ std::string("SSRS"),
/*read_cache_memory_type*/ std::string("shared"),
/*read_cache_levels*/ std::vector<int>{3},
/*write_cache_memory_type*/ std::string("local"),
/*write_cache_levels*/ std::vector<int>{2},
};
MultiLevelTiling multi_level_tiling(target_, mlt_config);
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
serial for (i, 0, 2)
{
serial for (j, 0, 8)
{
serial for (k, 0, 18)
{
serial for (a, 0, 18)
{
ScheduleBlock(pad_temp_0)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
}
}
}
}
}
}
{
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
{
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
{
serial for (k_1, 0, 1)
{
serial for (a_1, 0, 4)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
}
}
}
}
}
}
{
serial for (kernel_idx, 0, 3)
{
serial for (kernel_idx_0, 0, 3)
{
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
{
ScheduleBlock(pad_temp_0_shared_temp_buffer)
{
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
}
}
}
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
{
serial for (k_1, 0, 1)
{
serial for (a_1, 0, 4)
{
ScheduleBlock(var_0_local_temp_buffer)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
}
}
}
}
}
}
}
}
serial for (ax0_0, 0, 1)
{
serial for (ax1_0, 0, 4)
{
serial for (ax2_0, 0, 1)
{
serial for (ax3_0, 0, 4)
{
ScheduleBlock(var_0)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
}
}
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code
auto ir_module = BuildIRModule(new_states[0]->ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
input_shapes,
output_shapes,
target_);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include <string>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
namespace cinn {
namespace auto_schedule {
SkipRule::SkipRule(const common::Target& target) : AutoGenRule(target) {}
RuleApplyType SkipRule::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule;
num_applicable_ = 1;
return RuleApplyType::kApply;
}
std::string SkipRule::GetRuleName() const { return "SkipRule"; }
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
class SkipRule : public AutoGenRule {
public:
explicit SkipRule(const common::Target& target);
~SkipRule() = default;
RuleApplyType Init(ir::IRSchedule* init_schedule) override;
void Apply(int index) override {}
std::string GetRuleName() const override;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override {
return RuleApplyType::kApply;
}
std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) override {
return {state};
}
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace auto_schedule {
TEST(SkipRule, Basic) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
Expr M(32);
Expr N(128);
Placeholder<float> A("A", {M});
Placeholder<float> B("B", {N});
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
VLOG(6) << ast_expr;
SkipRule skip_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(skip_rule.Init(&ir_schedule), RuleApplyType::kApply);
EXPECT_EQ(skip_rule.NumberApplicable(), 1);
skip_rule.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
EXPECT_EQ(ast_expr, exprs[0]);
};
test_func(&ir_schedule);
test_func(&states[0]->ir_schedule);
}
TEST(SkipRule, ApplyOnSpecificBlock) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
Expr M(32);
Expr N(128);
Placeholder<float> A("A", {M});
Placeholder<float> B("B", {N});
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
VLOG(6) << ast_expr;
SkipRule skip_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
EXPECT_EQ(ast_expr, exprs[0]);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory.h>
#include <stdlib.h>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#ifdef CINN_WITH_CUDA
#include <cuda_runtime.h>
#endif
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::Instruction;
using ::cinn::hlir::framework::Scope;
using ::cinn::hlir::framework::Shape;
using ::cinn::hlir::framework::Tensor;
void TestAutoGenRuleBase::Initialize(const common::Target& target) {
target_ = target;
backend_compier_ = backends::Compiler::Create(target);
}
ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed,
bool apply_manual_schedule) {
Context::Global().ResetNameId();
auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
LOG_IF(WARNING, graph->fusion_groups.size() > 1)
<< "Test Graph has more than 1 group";
auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
lowered_funcs_ =
op_lowerer.Lower(graph->fusion_groups.front(),
/*apply_op_schedule = */ apply_manual_schedule,
/*apply_group_schedule = */ apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
std::vector<Expr> bodys;
for (auto&& func : lowered_funcs_) {
bodys.emplace_back(func->body);
}
return ir::IRSchedule(ir::ModuleExpr({std::move(bodys)}), rand_seed);
}
std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) {
const auto& exprs = schedule.GetModule().GetExprs();
std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr " << i << "\n";
}
return module_stream.str();
}
ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) {
auto&& updated_bodys = schedule.GetModule().GetExprs();
CHECK_EQ(lowered_funcs_.size(), updated_bodys.size())
<< "associated exprs size not equal";
ir::Module::Builder builder("test_bulder", this->target_);
for (int i = 0; i < lowered_funcs_.size(); ++i) {
ir::Expr func_body = updated_bodys.at(i);
const ir::LoweredFunc& ori_func = lowered_funcs_.at(i);
auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body);
builder.AddFunction(new_func);
}
return builder.Build();
}
std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) {
std::unique_ptr<backends::CodeGenC> codegen;
#ifdef CINN_WITH_CUDA
if (target_ == common::DefaultNVGPUTarget()) {
codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_);
} else {
codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
}
#else
codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
#endif
codegen->SetInlineBuiltinCodes(false);
return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl);
}
raw_func_type TestAutoGenRuleBase::GenExecutableKernel(
const ir::Module& ir_module) {
auto&& func_name = lowered_funcs_.front()->name;
// Compile to machine code
backend_compier_->Build(ir_module);
auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(
backend_compier_->Lookup(func_name));
return test_func_ptr;
}
void MemoryCopy(const float* src, float* dst, int numel, std::string type) {
#ifdef CINN_WITH_CUDA
if (type == "DeviceToHost") {
cudaMemcpy(dst, src, numel * sizeof(float), cudaMemcpyDeviceToHost);
return;
} else if (type == "HostToDevice") {
cudaMemcpy(dst, src, numel * sizeof(float), cudaMemcpyHostToDevice);
return;
}
#endif
if (type == "HostToHost") {
for (size_t i = 0; i < numel; ++i) {
dst[i] = src[i];
}
} else {
LOG(FATAL) << "Unknown memory copy type";
}
}
void AddDataToScope(Scope* scope,
const common::Target& target,
float* data_ptr,
std::string name,
const std::vector<int>& shape) {
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var);
CHECK(shape.size()) << "The size of shape can not be 0.";
Shape cinn_shape(shape);
tensor->Resize(cinn_shape);
auto* tgt_data_ptr = tensor->mutable_data<float>(target);
std::string mem_cpy_type =
target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type);
}
void CheckResult(raw_func_type test_func,
raw_func_type expected_func,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<std::vector<int>>& output_shapes,
const common::Target& target) {
CHECK(input_names.size()) << "The number of inputs must be greater than 0.";
CHECK(output_names.size()) << "The number of outputs must be greater than 0.";
CHECK_EQ(input_names.size(), input_shapes.size())
<< "The quantity of input_names and input_shapes must be equal.";
CHECK_EQ(output_names.size(), output_shapes.size())
<< "The quantity of output_names and output_shapes must be equal.";
// Initialize data
std::vector<float*> input_data_ptrs(input_names.size());
for (int i = 0; i < input_shapes.size(); ++i) {
int input_data_numel = std::accumulate(
input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) {
return a * b;
});
input_data_ptrs[i] =
reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float)));
for (int j = 0; j < input_data_numel; ++j) {
input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX; // NOLINT
}
}
std::vector<float*> test_output_data_ptrs(output_names.size());
std::vector<float*> expected_output_data_ptrs(output_names.size());
std::vector<int> output_data_numels(output_shapes.size());
for (int i = 0; i < output_shapes.size(); ++i) {
output_data_numels[i] = std::accumulate(
output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) {
return a * b;
});
test_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
expected_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(
expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
}
auto launch_kernel_fn = [&](raw_func_type& raw_func,
std::vector<float*>& output_data_ptrs) {
// Initialize scope
Scope scope;
// Initialize input data in scope.
for (int i = 0; i < input_names.size(); ++i) {
AddDataToScope(
&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]);
}
// Initialize output data in scope.
for (int i = 0; i < output_names.size(); ++i) {
AddDataToScope(&scope,
target,
output_data_ptrs[i],
output_names[i],
output_shapes[i]);
}
// Create Instruction and run
Instruction instr(target, &scope, input_names, output_names);
CHECK(raw_func) << "The raw_func can not be nullptr.";
instr.SetLoweredFunc(reinterpret_cast<void*>(raw_func));
// should call Finalize explicitly before Run
instr.Finalize();
instr.Run();
// data
for (int i = 0; i < output_names.size(); ++i) {
const float* result_ptr = scope.GetTensor(output_names[i])->data<float>();
std::string mem_cpy_type = target == common::DefaultNVGPUTarget()
? "DeviceToHost"
: "HostToHost";
MemoryCopy(
result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type);
}
};
// launch and execute test and expected kernel separately
launch_kernel_fn(test_func, test_output_data_ptrs);
launch_kernel_fn(expected_func, expected_output_data_ptrs);
// Check result
for (int i = 0; i < output_shapes.size(); ++i) {
for (int j = 0; j < output_data_numels[i]; ++j) {
ASSERT_NEAR(
test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4);
}
}
// Free memory
for (auto ptr : input_data_ptrs) {
free(ptr);
}
for (auto ptr : test_output_data_ptrs) {
free(ptr);
}
for (auto ptr : expected_output_data_ptrs) {
free(ptr);
}
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace auto_schedule {
/* @brief: Function pointer of executable code compiled by CINN.
* @params-1: Pointers to all arguments, including input and output.
* @params-2: The number of Arguments.
* @return: void
*/
using raw_func_type = void (*)(void**, int32_t);
// A base utility class for testing AutoGenRule
class TestAutoGenRuleBase : public ::testing::Test {
public:
void SetUp() override {
srand(0);
Context::Global().ResetNameId();
}
// Initialize context for specified target
void Initialize(const common::Target& target);
// construct an ir::IRSchedule by lowering the specified for following
// AutoGenRule test
ir::IRSchedule MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool apply_manual_schedule = false);
// Get the IR of bodies in IRSchedule
std::string GetIR(const ir::IRSchedule& schedule);
// build ir::Module from the original lowered funcs with their bodies updated
// by the schedule
ir::Module BuildIRModule(const ir::IRSchedule& schedule);
// generate source code with the built ir module
std::string GenSourceCode(const ir::Module& ir_module);
// generate executable kernel function with the built ir module
raw_func_type GenExecutableKernel(const ir::Module& ir_module);
protected:
common::Target target_;
std::vector<ir::LoweredFunc> lowered_funcs_;
std::unique_ptr<backends::Compiler> backend_compier_;
};
/* @brief: Interface for checking function correctness.
* @params-1: Function pointer of the function to be tested.
* @params-2: Expected function pointer for comparison.
* @params-3: Names of input data.
* @params-4: Names of output data.
* @params-5: Shapes of the input data, each input corresponds to a
* std::vector<int>.
* @params-6: Shapes of the output data, each output corresponds to a
* std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the
* function to be tested.
* @return: void
*/
void CheckResult(raw_func_type test_func,
raw_func_type expected_func,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<std::vector<int>>& output_shapes,
const common::Target& target);
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include <algorithm>
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace auto_schedule {
std::unique_ptr<BlockSampler> BlockSampler::Make(
const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy,
const std::string& strategy,
utils::LinearRandomEngine::StateType rand_seed,
const std::vector<int>& weights) {
CHECK_GT(all_blocks.size(), 0) << "Empty block list";
if (strategy == "traversal") {
VLOG(6) << "Init TraversalBlockSampler with block num = "
<< all_blocks.size();
return std::make_unique<TraversalBlockSampler>(all_blocks,
default_remove_policy);
} else if (strategy == "probabilistic") {
VLOG(6) << "Init ProbabilisticBlockSampler with block num = "
<< all_blocks.size();
return std::make_unique<ProbabilisticBlockSampler>(
all_blocks, default_remove_policy, rand_seed, weights);
}
LOG(FATAL) << "Unimplemented strategy:" << strategy;
return nullptr;
}
BlockSampler::BlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy) {
default_remove_policy_ = default_remove_policy;
std::transform(all_blocks.begin(),
all_blocks.end(),
std::back_inserter(all_blocks_),
[](const ir::Expr& block_expr) {
const ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
return block->name;
});
}
std::string TraversalBlockSampler::NextBlock(bool remove) {
if (cur_idx_ < all_blocks_.size()) {
VLOG(6) << "[TraversalBlockSampler] next block: "
<< all_blocks_.at(cur_idx_);
std::string block_name = all_blocks_.at(cur_idx_);
if (remove) {
++cur_idx_;
}
return block_name;
}
VLOG(6) << "[TraversalBlockSampler] next block: empty";
return "";
}
ProbabilisticBlockSampler::ProbabilisticBlockSampler(
const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed,
const std::vector<int>& weights)
: BlockSampler(all_blocks, default_remove_policy),
weights_(weights),
rand_seed_(rand_seed) {
if (weights.empty()) {
weights_.resize(all_blocks.size(), 1);
} else {
CHECK_EQ(all_blocks.size(), weights_.size());
}
remains_ = all_blocks.size();
}
std::string ProbabilisticBlockSampler::NextBlock(bool remove) {
if (remains_ == 0) {
return "";
}
int block_idx =
utils::SampleDiscreteFromDistribution<int>(weights_, &rand_seed_);
if (remove) {
weights_[block_idx] = 0;
--remains_;
}
VLOG(6) << "[ProbabilisticBlockSampler] next block: "
<< all_blocks_.at(block_idx);
return all_blocks_.at(block_idx);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace auto_schedule {
class SearchState;
// Select the next block to be operated for SearchState during the search
// process
class BlockSampler {
public:
/**
* @brief Create a BlockSampler with the specific strategy name and necessary
* construct parameters.
* @param all_blocks All possible blocks to be sampled.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The block sampling strategy.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select blocks one by one until
* all blocks are traversed, and "probabilistic" means randomly picking blocks
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static std::unique_ptr<BlockSampler> Make(
const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy = true,
const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
// Return the name of sample strategy
virtual const char* Name() const = 0;
// Reset associated states to sample at the beginning
virtual void Reset() = 0;
// Select a block with default remove policy.
std::string NextBlock() { return NextBlock(default_remove_policy_); }
protected:
// A BlockSampler object should be created with the static function Make()
BlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy);
// Select a block to apply rule
// The param remove is used to determine whether to delete the next block
// after selecting it, If remove == true, it will not be sampled in the
// future.
virtual std::string NextBlock(bool remove) = 0;
// The names of all blocks
// Because the Block Expr will be changed in the search process, the name is
// saved for indexing
std::vector<std::string> all_blocks_;
// The default policy to determine whether to delete the next block after
// selecting it.
bool default_remove_policy_;
};
// Sample blocks with traversal strategy,
// witch means to select blocks one by one until all blocks are traversed.
class TraversalBlockSampler : public BlockSampler {
public:
TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy)
: BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; }
void Reset() override { cur_idx_ = 0; }
private:
std::string NextBlock(bool remove) override;
private:
int cur_idx_;
};
// Sample blocks with probabilistic strategy,
// witch means randomly picking blocks according to the given distribution.
class ProbabilisticBlockSampler : public BlockSampler {
public:
ProbabilisticBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
const char* Name() const override { return "probabilistic"; }
void Reset() override {}
private:
std::string NextBlock(bool remove) override;
private:
std::vector<int> weights_;
utils::LinearRandomEngine::StateType rand_seed_;
int remains_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace auto_schedule {
std::vector<ir::Expr> CreateTestBlocks() {
std::vector<ir::Expr> blocks;
for (int i = 0; i < 3; ++i) {
ir::Expr block = ir::ScheduleBlock::Make(
{}, {}, {}, "block_" + std::to_string(i), ir::Expr());
blocks.push_back(ir::ScheduleBlockRealize::Make({}, block));
}
return blocks;
}
TEST(BlockSampler, Make) {
std::vector<ir::Expr> mock_blocks = CreateTestBlocks();
auto traversal_block_sampler =
BlockSampler::Make(mock_blocks, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler =
BlockSampler::Make(mock_blocks, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
}
TEST(TraversalBlockSampler, NextBlock) {
std::vector<ir::Expr> blocks = CreateTestBlocks();
auto traversal_block_sampler = BlockSampler::Make(blocks, true, "traversal");
ASSERT_EQ("block_0", traversal_block_sampler->NextBlock());
ASSERT_EQ("block_1", traversal_block_sampler->NextBlock());
ASSERT_EQ("block_2", traversal_block_sampler->NextBlock());
ASSERT_EQ("", traversal_block_sampler->NextBlock());
traversal_block_sampler->Reset();
ASSERT_EQ("block_0", traversal_block_sampler->NextBlock());
traversal_block_sampler = BlockSampler::Make(blocks, false, "traversal");
ASSERT_EQ("block_0", traversal_block_sampler->NextBlock());
ASSERT_EQ("block_0", traversal_block_sampler->NextBlock());
}
TEST(ProbabilisticBlockSampler, NextBlock) {
std::vector<ir::Expr> blocks = CreateTestBlocks();
auto probabilistic_block_sampler =
BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
std::string block_name;
for (int i = 0; i < 20; ++i) {
block_name = probabilistic_block_sampler->NextBlock();
VLOG(6) << "next block name: " << block_name;
}
probabilistic_block_sampler =
BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock();
ASSERT_EQ("", probabilistic_block_sampler->NextBlock());
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include <algorithm>
#include <random>
namespace cinn {
namespace auto_schedule {
std::unique_ptr<RuleSampler> RuleSampler::Make(
const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy,
const std::string& strategy,
utils::LinearRandomEngine::StateType rand_seed,
const std::vector<int>& weights) {
CHECK_GT(potential_rules.size(), 0) << "Empty rule list";
if (strategy == "traversal") {
return std::make_unique<TraversalRuleSampler>(potential_rules,
default_remove_policy);
} else if (strategy == "probabilistic") {
return std::make_unique<ProbabilisticRuleSampler>(
potential_rules, default_remove_policy, rand_seed, weights);
}
LOG(FATAL) << "Unimplemented strategy:" << strategy;
return nullptr;
}
AutoGenRule* TraversalRuleSampler::NextRule(bool remove) {
if (cur_idx_ < potential_rules_->size()) {
AutoGenRule* rule = potential_rules_->at(cur_idx_);
if (remove) {
++cur_idx_;
}
return rule;
}
return nullptr;
}
ProbabilisticRuleSampler::ProbabilisticRuleSampler(
const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed,
const std::vector<int>& weights)
: RuleSampler(potential_rules, default_remove_policy),
weights_(weights),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
if (weights.empty()) {
weights_.resize(potential_rules.size(), 1);
} else {
CHECK_EQ(potential_rules.size(), weights_.size());
}
remains_ = potential_rules.size();
}
AutoGenRule* ProbabilisticRuleSampler::NextRule(bool remove) {
if (remains_ == 0) {
return nullptr;
}
int rule_idx =
utils::SampleDiscreteFromDistribution<int>(weights_, &rand_seed_);
if (remove) {
weights_[rule_idx] = 0;
--remains_;
}
return potential_rules_->at(rule_idx);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace auto_schedule {
class SearchState;
// Select the next potential rule for the SearchState during the search process.
class RuleSampler {
public:
/**
* @brief Create a RuleSampler with the specific strategy name and necessary
* construct parameters.
* @param potential_rules All possible rules to be sampled.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The rule sampling strategy.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select rules one by one until
* all rules are traversed, and "probabilistic" means randomly picking rules
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static std::unique_ptr<RuleSampler> Make(
const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy = true,
const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
// Return the name of sample strategy
virtual const char* Name() const = 0;
// Reset associated states to sample at the beginning
virtual void Reset() = 0;
// Select a rule with default remove policy.
AutoGenRule* NextRule() { return NextRule(default_remove_policy_); }
protected:
// A RuleSampler object should be created with the static function Make()
RuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: potential_rules_(&potential_rules),
default_remove_policy_(default_remove_policy) {}
// Select a rule to apply.
// The param remove is used to determine whether to delete the next rule after
// selecting it, If remove == true, it will not be sampled in the future.
virtual AutoGenRule* NextRule(bool remove) = 0;
// The pointer refers to all potential rules
const std::vector<AutoGenRule*>* potential_rules_;
// The default policy to determine whether to delete the next rule after
// selecting it.
bool default_remove_policy_;
};
// Sample rules with traversal strategy,
// witch means to select rules one by one until all rules are traversed.
class TraversalRuleSampler : public RuleSampler {
public:
TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; }
void Reset() override { cur_idx_ = 0; }
private:
AutoGenRule* NextRule(bool remove) override;
private:
int cur_idx_;
};
// Sample rules with probabilistic strategy,
// which means randomly picking rules according to the given distribution.
class ProbabilisticRuleSampler : public RuleSampler {
public:
ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
const char* Name() const override { return "probabilistic"; }
void Reset() override {}
private:
AutoGenRule* NextRule(bool remove) override;
private:
std::vector<int> weights_;
utils::LinearRandomEngine::StateType rand_seed_;
int remains_;
};
} // namespace auto_schedule
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment