Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
namespace cinn {
namespace auto_schedule {
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
std::vector<AutoGenRule*> GenerateTestRules() {
return {new AutoUnroll(target), new SkipRule(target)};
}
TEST(RuleSampler, Make) {
std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler =
RuleSampler::Make(rules, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
}
TEST(TraversalRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal");
AutoGenRule* rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName());
rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("SkipRule", rule->GetRuleName());
traversal_rule_sampler->Reset();
rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName());
traversal_rule_sampler = RuleSampler::Make(rules, false, "traversal");
rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName());
rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName());
}
TEST(ProbabilisticRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules();
auto probabilistic_rule_sampler =
RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
AutoGenRule* rule;
for (int i = 0; i < 20; ++i) {
rule = probabilistic_rule_sampler->NextRule();
VLOG(6) << "next rule name: " << rule->GetRuleName();
}
probabilistic_rule_sampler =
RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
probabilistic_rule_sampler->NextRule();
probabilistic_rule_sampler->NextRule();
ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule());
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include <glog/logging.h>
#include <cstdlib>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
SearchSpace::SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed)
: tune_task_(tune_task),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
const auto& target = tune_task_.target;
// initialize a set of rules and they are commonly used by all states
// TODO(zhhsplendid): pass correct output names to AutoInline
// sketch_rules_.emplace_back(new AutoInline(target,
// tune_task_.output_names));
sketch_rules_.emplace_back(
new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
sketch_rules_.emplace_back(new AutoUnroll(target));
sketch_rules_.emplace_back(new SkipRule(target));
}
SearchState SearchSpace::GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model) {
bool has_manual_schedule = false;
if (has_manual_schedule) {
SearchState ret = ManualScheduleMutate(state);
return ret;
}
SearchState ret = RandomScheduleMutate(state);
if (FLAGS_auto_schedule_use_cost_model) {
ret->predicted_cost =
cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
}
VLOG(4) << JoinStatesDebugString(
"SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
return ret;
}
SearchState SearchSpace::ManualScheduleMutate(const SearchState& state) {
// TODO(zhhsplendid): Add manual schedule mutate
return state;
}
SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
// 1. Found the schedules which can apply on this Expr
// 2. Make a distribution on those schedules
std::map<int, int> weight_to_rule_index;
int cur_weight = 0;
SearchState ret(state);
std::vector<RuleApplyType> apply_types(ret->applicable_rules.size());
for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) {
AutoGenRule* rule = ret->applicable_rules.at(idx);
RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "="
<< static_cast<int>(apply_type);
apply_types[idx] = apply_type;
if (apply_type != RuleApplyType::kCannotApply) {
weight_to_rule_index[cur_weight] = idx;
cur_weight += rule->NumberApplicable();
}
}
if (weight_to_rule_index.empty()) {
// No applicable rule, return the input mod_expr
VLOG(6) << "No applicable rule";
return ret;
}
// 3. Sample a schedule on the distribution
int sample_weighted_index =
utils::SampleUniformInt(0, cur_weight, &rand_seed_);
auto iter = weight_to_rule_index.upper_bound(sample_weighted_index);
--iter;
int sample_rule_index = iter->second;
CHECK_LT(sample_rule_index, ret->applicable_rules.size());
AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index);
VLOG(7) << "Apply rule: " << sample_rule->GetRuleName()
<< " with index=" << sample_weighted_index - iter->first;
// 4. Apply the schedule change
sample_rule->Apply(sample_weighted_index - iter->first);
// 5. Remove the rule after applying it
if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
ret->applicable_rules.erase(ret->applicable_rules.begin() +
sample_rule_index);
}
return ret;
}
std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num;
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(),
sketch_rules_.end(),
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
std::vector<SearchState> result;
while (result.size() < num) {
SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules);
for (int i = 0; i < init_sketch_random_depth_; ++i) {
VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: "
<< i;
state = RandomScheduleMutate(state);
if (state->applicable_rules.empty()) {
break;
}
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::GetRandomInitialSketch-New_Sketch",
{state},
/*verbose=*/VLOG_IS_ON(6));
result.emplace_back(std::move(state));
}
return result;
}
std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy";
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks();
auto block_sampler = BlockSampler::Make(
all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(),
sketch_rules_.end() - 1,
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
std::vector<SearchState> states_buf1{init_state}, states_buf2;
std::vector<SearchState>* p_states_cur = &states_buf1;
std::vector<SearchState>* p_states_next = &states_buf2;
int total_steps = 0, steps;
std::string block_name;
while ("" != (block_name = block_sampler->NextBlock()) &&
total_steps < init_sketch_random_depth_) {
steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_);
if (total_steps + steps > init_sketch_random_depth_) {
steps = init_sketch_random_depth_ - total_steps;
}
total_steps += steps;
p_states_next->clear();
for (const auto& state : *p_states_cur) {
auto rule_sampler =
RuleSampler::Make(init_rules,
true,
"probabilistic",
utils::ForkRandomState(&rand_seed_));
auto new_states = ApplySketchRule(
state, block_name, rule_sampler.get(), steps, false, 1);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
}
std::swap(p_states_cur, p_states_next);
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRandomPrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur;
}
std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy";
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks();
std::reverse(all_blocks.begin(), all_blocks.end());
auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal");
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(),
sketch_rules_.end() - 1,
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
std::vector<SearchState> states_buf1{init_state}, states_buf2;
std::vector<SearchState>* p_states_cur = &states_buf1;
std::vector<SearchState>* p_states_next = &states_buf2;
std::string block_name;
while ("" != (block_name = block_sampler->NextBlock())) {
p_states_next->clear();
for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal");
auto new_states =
ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
}
std::swap(p_states_cur, p_states_next);
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRulePrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur;
}
std::vector<SearchState> SearchSpace::GenerateSketches(
int num, const std::string& strategy) {
VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;
if (strategy == "random") {
return InitSketchWithRandomStrategy(num);
}
std::vector<SearchState> result;
while (result.size() < num) {
std::vector<SearchState> sketchs;
if (strategy == "rule_prune") {
sketchs = InitSketchWithRulePrunedStrategy();
} else if (strategy == "random_prune") {
sketchs = InitSketchWithRandomPrunedStrategy();
} else {
LOG(FATAL) << "Unimplemented init sketch strategy";
}
// the more rules are applied, the greater the possibility of good results,
// the more rules are applied, the more they are saved behind the queue,
// so we give priority to the results in the rear
for (auto iter = sketchs.rbegin(); iter != sketchs.rend(); ++iter) {
result.push_back(*iter);
if (result.size() == num) {
break;
}
}
}
VLOG(4) << JoinStatesDebugString(
"SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
return result;
}
std::vector<SearchState> SearchSpace::ApplySketchRule(
const SearchState& state,
const std::string& block_name,
RuleSampler* rule_sampler,
int steps,
bool prune_by_rule,
double prune_probability) {
std::list<SearchState> layer{state};
int step = 0;
AutoGenRule* rule;
// After determining a SearchState and a block, each rule has two
// possibilities: apply and not apply. In all transfer spaces, select a rule
// at each step, and collect all possible new states arrived by apply and not
// apply. This forms a tree, and we can use rule pruning or random pruning to
// reduce the number of sketches.
VLOG(6) << "Collect the states of all transfers within steps: " << steps;
while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
std::list<SearchState> new_states;
int id = 0;
for (std::list<SearchState>::iterator iter = layer.begin();
iter != layer.end();) {
// Some rules will reduce the number of blocks, such as AutoInline,
// so we need to check whether the SearchState still has the block.
if (!(*iter)->ir_schedule.HasBlock(block_name)) {
++iter;
continue;
}
auto type = rule->AnalyseApplyType(*iter, block_name);
VLOG(7)
<< "At SearchState " << ++id << ", apply type = "
<< static_cast<typename std::underlying_type<RuleApplyType>::type>(
type);
// if cannot apply the rule, skip it
if (type == RuleApplyType::kCannotApply) {
++iter;
continue;
}
// if can apply the rule, apply it and determine whether to prune the
// branch that do not apply
std::vector<SearchState> tmp_states =
rule->ApplyOnBlock(*iter, block_name);
new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
bool need_prune = false;
if (prune_by_rule) {
need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
} else {
need_prune =
(utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
}
if (need_prune) {
iter = layer.erase(iter);
} else {
++iter;
}
}
VLOG(7) << "apply on block: " << block_name << ", generate "
<< new_states.size() << " new states at step " << step;
layer.splice(layer.end(), std::move(new_states));
}
VLOG(6) << "apply on block: " << block_name << ", generate "
<< layer.size() - 1 << " more states at all";
return std::vector<SearchState>(layer.begin(), layer.end());
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
* This class is an abstraction of the transformations can be applied to
* ir::Expr during auto-tuning. The transformation can be:
*
* 1. Manual defined schedule
* 2. Schedule generated by AutoGenRule
*
* TODO(zhhsplendid): de-duplication the generated ModuleExpr
*/
class SearchSpace {
public:
SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed = -1);
// Sketch mutate, returns the mutated ModuleExpr and estimited cost
virtual SearchState GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model);
/**
* \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate.
* @param strategy The strategy to generate sketchs,
* Current optional strategies are "rule_prune" or "random_prune" or
* "random".
* - "rule_prune": will use rules to prune and generate sketches as
* efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate
* the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketchs.
*/
virtual std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy);
private:
// TODO(zhhsplendid): mutate by manual schedule.
SearchState ManualScheduleMutate(const SearchState& state);
// mutate by sketch rules randomly
SearchState RandomScheduleMutate(const SearchState& state);
// Generate num sketchs, each with several rounds of SketchMutate
std::vector<SearchState> InitSketchWithRandomStrategy(int num);
// Generate sketch pruned randomly as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRandomPrunedStrategy();
// Generate sketch pruned by rules as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRulePrunedStrategy();
/**
* @brief Collect the new states that may be transferred to after applying
* several rules on a block from a certain state.
* @param state Starting point of state transition.
* @param block_name Name of the block to apply the rules to.
* @param rule_sampler Sampler that samples the new rule to apply on the
* block.
* @param steps Number of steps to apply the rule.
* @param prune_by_rule If true, prune the state transition tree by rule,
* otherwise prune randomly.
* @param prune_probability Pruning probability of random pruning.
*/
std::vector<SearchState> ApplySketchRule(const SearchState& state,
const std::string& block_name,
RuleSampler* rule_sampler,
int steps,
bool prune_by_rule,
double prune_probability = 1);
private:
const TuneTask& tune_task_;
int init_sketch_random_depth_ = 6;
// supported AutoGenRules, every task holds a set
std::vector<std::unique_ptr<AutoGenRule>> sketch_rules_;
utils::LinearRandomEngine::StateType rand_seed_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include <gtest/gtest.h>
namespace cinn {
namespace auto_schedule {} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace auto_schedule {
SearchState::SearchState(ir::IRSchedule ir_sch,
float cost,
const std::vector<AutoGenRule*>& rules)
: common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
auto* state = get();
state->ir_schedule = std::move(ir_sch);
state->applicable_rules = rules;
state->predicted_cost = cost;
}
SearchState SearchState::Copy() const {
return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {});
}
std::string _SearchState_::DebugString() const {
const auto& exprs = ir_schedule.GetModule().GetExprs();
std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr";
}
const char* fmt_str = R"ROC(
ModuleExpr {
%s
} // end ModuleExpr
ScheduleDesc {
%s
} // end ScheduleDesc
predicted_cost: %f)ROC";
return utils::StringFormat(fmt_str,
module_stream.str().c_str(),
ir_schedule.GetTraceDesc().DebugString().c_str(),
predicted_cost);
}
bool operator<(const SearchState& left, const SearchState& right) {
return left->predicted_cost < right->predicted_cost;
}
// Visit every node by expanding all of their fields in dfs order
class DfsWithExprsFields : public ir::IRVisitorRequireReImpl<void> {
protected:
#define __m(t__) \
void Visit(const ir::t__* x) override { \
for (auto* n : x->expr_fields()) { \
if (n->defined()) { \
Visit(n); \
} \
} \
}
NODETY_FORALL(__m)
#undef __m
void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); }
};
// Generate a reduce hash of a AST tree by combining hash of each AST node
class IrNodesStructuralHash : public DfsWithExprsFields {
public:
explicit IrNodesStructuralHash(size_t init_key) : hash_key_(init_key) {}
size_t operator()(const Expr* expr) {
Visit(expr);
return hash_key_;
}
void Visit(const Expr* expr) override {
static decltype(ir::kIrNodeTyReprs) Node2Name = ir::kIrNodeTyReprs;
if (!expr->defined()) return;
auto type_code = static_cast<IrNodeTyUnderlyingType>(expr->node_type());
hash_key_ = utils::HashCombine(hash_key_, type_code);
DfsWithExprsFields::Visit(expr);
}
private:
void Visit(const ir::_Tensor_* x) override {
for (auto& e : x->shape) {
Visit(&e);
}
DfsWithExprsFields::Visit(x->buffer.As<ir::_Buffer_>());
}
using IrNodeTyUnderlyingType = std::underlying_type<ir::IrNodeTy>::type;
size_t hash_key_;
};
size_t SearchStateHash::operator()(const SearchState& s) const {
size_t hash_key = 0;
const auto& exprs = s->ir_schedule.GetModule().GetExprs();
for (auto&& expr : exprs) {
hash_key = IrNodesStructuralHash(hash_key)(&expr);
}
return hash_key;
}
bool SearchStateEqual::operator()(const SearchState& lhs,
const SearchState& rhs) const {
const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();
// compare exprs size firstly
if (lhs_exprs.size() != rhs_exprs.size()) return false;
// compare every expr one by one with ir::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
}
return true;
}
std::string JoinStatesDebugString(const std::string& title,
const std::vector<SearchState>& states,
bool verbose) {
std::stringstream ss;
ss << title << " states size:" << states.size() << "\n";
SearchStateHash state_hasher;
for (size_t i = 0; i < states.size(); ++i) {
uint64_t hash_key = state_hasher(states[i]);
if (verbose) {
ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>"
<< states[i]->DebugString() << "\n<------";
} else {
ss << "\tState-" << i << " hash:" << hash_key << "\n";
}
}
return std::move(*ss.rdbuf()).str();
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <limits>
#include <vector>
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace auto_schedule {
struct _SearchState_;
class AutoGenRule;
//! Shared Wrapper for _SearchState_
class SearchState : public common::Shared<_SearchState_> {
public:
SearchState() = default;
// create a new SearchState
explicit SearchState(ir::IRSchedule ir_sch,
float cost = NOT_INIT_COST,
const std::vector<AutoGenRule*>& rules = {});
// Constant standing for a cost not being initialized
static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();
// compare function for two states
friend bool operator<(const SearchState& left, const SearchState& right);
// Deep copy a SearchState
SearchState Copy() const;
};
//! Class to store immediate states during search
struct _SearchState_ : public common::Object {
// IRSchedule contains ir::ModuleExpr and trace scheduling process
ir::IRSchedule ir_schedule;
// Cost model predicted cost
float predicted_cost;
// The rules that can be applied to the IRSchedule at this state.
std::vector<AutoGenRule*> applicable_rules;
// return detail string of content for debug;
std::string DebugString() const;
const char* type_info() const override { return __type_info__; }
static constexpr char* __type_info__ = "auto_schedule_state";
};
// SearchStateHash hash functor that visits every AST node and combine their
// hash of node_type in dfs order
struct SearchStateHash {
size_t operator()(const SearchState& s) const;
};
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
/*!
* \brief concatenate debug strings of all states with additional info
* \param title head of the result string
* \param states SearchState array to be debugged
* \param verbose whether to enable more verbose debug info
* \return the concatenated debug string
*/
std::string JoinStatesDebugString(const std::string& title,
const std::vector<SearchState>& states,
bool verbose = false);
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
namespace cinn {
namespace auto_schedule {
TEST(TestSearchState, SearchStateHash_Equal) {
Target target = common::DefaultHostTarget();
ir::Expr M(32);
ir::Expr N(32);
lang::Placeholder<float> A("A", {M, N});
ir::Tensor B = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + ir::Expr(2.f); }, "B");
ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId();
auto a_plus_b_funcs = lang::LowerVec("A_plus_B",
poly::CreateStages({A, C}),
{A, C},
{},
{},
nullptr,
target,
true);
std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
}
})ROC";
std::string a_plus_const_funcs_2_str = R"ROC(function A_plus_const (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
}
})ROC";
std::string a_plus_b_funcs_str = R"ROC(function A_plus_B (_A, _C)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(C)
{
i0_0, i1_0 = axis.bind(i, j)
C[i0_0, i1_0] = (A[i0_0, i1_0] + B[i0_0, i1_0])
}
}
}
}
}
})ROC";
ASSERT_EQ(a_plus_const_funcs_1.size(), 1);
EXPECT_EQ(a_plus_const_funcs_1_str,
utils::GetStreamCnt(a_plus_const_funcs_1.front()));
ASSERT_EQ(a_plus_const_funcs_2.size(), 1);
EXPECT_EQ(a_plus_const_funcs_2_str,
utils::GetStreamCnt(a_plus_const_funcs_2.front()));
ASSERT_EQ(a_plus_b_funcs.size(), 1);
EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front()));
SearchState a_plus_const_state1(
ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body})));
SearchState a_plus_const_state2(
ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body})));
SearchState a_plus_b_state(
ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body})));
SearchStateHash hash_functor;
SearchStateEqual equal_functor;
ASSERT_EQ(hash_functor(a_plus_const_state1),
hash_functor(a_plus_const_state2));
ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2));
ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state));
ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state));
}
} // namespace auto_schedule
} // namespace cinn
add_subdirectory(mutate_rule)
core_gather_headers()
gather_srcs(cinnapi_src SRCS evolutionary_search.cc)
cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
cinncore test_program_builder)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h"
#include <glog/logging.h>
#include <algorithm>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/utils/multi_threading.h"
#include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
EvolutionarySearch::EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model,
Database* database,
utils::LinearRandomEngine::StateType rand_seed,
const std::vector<std::tuple<std::string, double>>& mutate_rules)
: tune_task_(tune_task),
cost_model_(cost_model),
database_(database),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)),
mutators_(mutate_rules) {
search_space_ = std::make_unique<SearchSpace>(
tune_task, utils::ForkRandomState(&rand_seed_));
if (mutators_.empty()) {
mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0));
}
double accum_weight = 0.0;
for (const auto& mutator : mutators_) {
if (std::get<1>(mutator) > 0) {
accum_weight += std::get<1>(mutator);
weighted_mutators_.insert(
std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator))));
}
}
post_schedule_rules_.emplace_back(new CooperativeProcess);
}
EvolutionarySearch::~EvolutionarySearch() {}
SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) {
return SearchModuleExprBests(options)[0];
}
std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(
const TuningOptions& options) {
VLOG(4) << "start SearchModuleExprBests with initial statistics: "
"visited_candidates size="
<< visited_candidates_.size();
std::vector<SearchState> init_population;
std::vector<SearchState> topk_from_database =
GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::GetTopKCandidatesFromDatabase",
topk_from_database,
/*verbose=*/VLOG_IS_ON(5));
int init_num =
options.evolution_init_population_num - topk_from_database.size();
std::vector<SearchState> init_sketch = InitSketch(init_num, "rule_prune");
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5));
init_population.insert(init_population.end(),
topk_from_database.begin(),
topk_from_database.end());
init_population.insert(
init_population.end(), init_sketch.begin(), init_sketch.end());
std::vector<SearchState> picked_bests =
Evolve(init_population,
options.evolution_cross_over_num,
options.num_samples_per_iteration);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5));
return picked_bests;
}
std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(
const TuningOptions& options) {
std::vector<SearchState> picked_bests = SearchModuleExprBests(options);
int random_num = options.evolution_init_population_num -
options.evolution_pick_database_topk;
auto results =
PickNextGenerationEpsGreedy(picked_bests,
InitSketch(random_num, "random_prune"),
options.num_samples_per_iteration,
options.evolution_eps_greedy);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy",
results,
/*verbose=*/VLOG_IS_ON(5));
return results;
}
std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
int topk) {
std::vector<SearchState> results;
const auto& task_key = tune_task_.serialized_key;
auto records = database_->GetTopK(task_key, topk);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) {
ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
}
return results;
}
void ApplyPostScheduleRules(
ir::IRSchedule* schedule,
const std::vector<std::unique_ptr<PostScheduleRule>>& post_schedule_rules) {
schedule->TagPostSchedule();
for (const auto& post_rule : post_schedule_rules) {
post_rule->Apply(schedule);
}
}
std::vector<SearchState> EvolutionarySearch::InitSketch(
int num, const std::string& strategy) {
VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy;
std::vector<SearchState> states =
search_space_->GenerateSketches(num, strategy);
auto post_schedule_fn = [this, &states](int index) {
ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_);
};
utils::parallel_run(post_schedule_fn,
utils::SequenceDispatcher(0, states.size()),
states.size());
return states;
}
SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
const SearchState& state2) {
// TODO(CtfGo): tracing CrossOver with IRSchedule
std::vector<ir::Expr> cross_over_exprs;
std::vector<ir::Expr> father_exprs =
state1->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> mother_exprs =
state2->ir_schedule.GetModule().GetExprs();
CHECK_EQ(father_exprs.size(), mother_exprs.size())
<< "CrossOver ModuleExpr in EvolutionarySearch must have same number of "
"AST";
for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
cross_over_exprs.push_back(optim::IRCopy(father_exprs[i]));
} else {
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
}
}
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
utils::ForkRandomState(&rand_seed_)));
if (FLAGS_auto_schedule_use_cost_model) {
res->predicted_cost =
cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target);
}
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver",
{state1, state2, res},
/*verbose=*/VLOG_IS_ON(6));
return res;
}
SearchState EvolutionarySearch::Mutate(
const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) {
CHECK_GT(weighted_mutators_.size(), 0)
<< "There is no mutate rule can be applied.";
double accu_weight = (weighted_mutators_.rbegin())->first;
CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0.";
// sample a mutate rule
double sample_weight = utils::SampleUniformDouble(0, accu_weight, rand_seed);
auto sampled_iter = weighted_mutators_.upper_bound(sample_weight);
MutateRule* mutator = sampled_iter->second.get();
CHECK(mutator) << "mutator not defined";
// apply mutation on the trace of SearchState
auto trace = state->ir_schedule.GetTraceDesc();
auto new_trace = mutator->Apply(trace, rand_seed);
// replay the mutated trace on original ModuleExpr to generate a new
// ir_schedule
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch));
VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
return res;
}
std::vector<SearchState> EvolutionarySearch::Evolve(
const std::vector<SearchState>& population,
int cross_over_num,
int ret_num) {
VLOG(4) << utils::StringFormat(
"Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu",
population.size(),
cross_over_num,
ret_num);
int generation_num = population.size();
if (generation_num == 0) {
return std::vector<SearchState>();
}
// init evolution
std::vector<SearchState> evolution(population);
for (SearchState& search_state : evolution) {
if (search_state->predicted_cost == SearchState::NOT_INIT_COST &&
FLAGS_auto_schedule_use_cost_model) {
search_state->predicted_cost = cost_model_.Predict(
search_state->ir_schedule.GetModule(), tune_task_.target);
}
}
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: Init evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// cross over
for (int i = 0; i < cross_over_num; ++i) {
int first_rand_idx =
utils::SampleUniformInt(0, generation_num, &rand_seed_);
int second_rand_idx =
utils::SampleUniformInt(0, generation_num, &rand_seed_);
while (first_rand_idx == second_rand_idx) {
second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
}
evolution.push_back(
CrossOver(population[first_rand_idx], population[second_rand_idx]));
}
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: after CrossOver evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// mutate
std::vector<SearchState> mutated_individuals(evolution.size());
std::vector<utils::LinearRandomEngine::StateType> rand_seeds(
evolution.size());
for (int i = 0; i < rand_seeds.size(); ++i) {
rand_seeds[i] = utils::ForkRandomState(&rand_seed_);
}
auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](
int index) {
mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]);
};
utils::parallel_run(mutate_fn,
utils::SequenceDispatcher(0, evolution.size()),
evolution.size());
if (FLAGS_auto_schedule_use_cost_model) {
for (size_t i = 0; i < mutated_individuals.size(); ++i) {
mutated_individuals[i]->predicted_cost = cost_model_.Predict(
mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target);
}
}
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: mutated individuals:",
mutated_individuals,
/*verbose=*/VLOG_IS_ON(5));
// select top ret_num with predicted cost
utils::SizedMultiSet<SearchState> evolution_with_cost(ret_num);
for (size_t i = 0; i < evolution.size(); ++i) {
evolution_with_cost.Push(evolution[i]);
}
for (size_t i = 0; i < mutated_individuals.size(); ++i) {
evolution_with_cost.Push(mutated_individuals[i]);
}
auto selected_individuals =
evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>();
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: selected individuals:",
selected_individuals,
/*verbose=*/VLOG_IS_ON(5));
return selected_individuals;
}
std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(
const std::vector<SearchState>& picked_bests,
const std::vector<SearchState>& random_init,
int num,
float eps_greedy) {
int num_rands = num * eps_greedy;
int num_bests = num - num_rands;
std::vector<SearchState> result;
SearchState selected;
int deduplicated_cnt = 0;
int best_idx = 0;
int rand_idx = 0;
while (result.size() < num) {
if (result.size() < num_bests && best_idx < picked_bests.size()) {
selected = picked_bests[best_idx];
++best_idx;
} else if (rand_idx < random_init.size()) {
selected = random_init[rand_idx];
++rand_idx;
} else if (best_idx < picked_bests.size()) {
selected = picked_bests[best_idx];
++best_idx;
} else {
break;
}
if (!visited_candidates_.count(selected)) { // deduplicate
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Selected",
{selected},
/*verbose=*/VLOG_IS_ON(5));
visited_candidates_.insert(selected);
result.push_back(selected);
} else {
++deduplicated_cnt;
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated",
{selected},
/*verbose=*/VLOG_IS_ON(5));
}
}
VLOG(4) << utils::StringFormat(
"PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init "
"size=%lu,num=%d,"
"eps_greedy=%f,deduplicated_cnt=%d,result size=%lu",
picked_bests.size(),
random_init.size(),
num,
eps_greedy,
deduplicated_cnt,
result.size());
return result;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
* Class implement the evolutionary search on ModuleExpr search space.
*/
class EvolutionarySearch {
public:
/**
* constructor with TuneTask.
*
* @param tune_task: the TuneTask this class works on. This class doesn't
* take ownership of the pointer.
*/
EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model,
Database* database,
utils::LinearRandomEngine::StateType rand_seed = -1,
const std::vector<std::tuple<std::string, double>>& mutate_rules = {});
/**
* Destructor
*/
~EvolutionarySearch();
/**
* Run the evolutionary search for one iteration.
*
* @return SearchState containing the best ir::ModuleExpr searched in this
* iteration
*/
SearchState SearchModuleExpr(const TuningOptions& options);
/**
* Run the evolutionary search for one iteration.
*
* @return SearchState(s) containing best ir::ModuleExpr(s) searched in this
* iteration
*/
std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options);
/**
* Run the evolutionary search for one iteration, but since evolutionary
* search with cost model may not be accurate, this method picks
* "eps * total_return_size" random samples along with those best
* ir::ModuleExpr's searched in this iteration.
*
* @return SearchSpace containing those best ir::ModuleExpr's searched
* in this iteration and some random samples. There are
* "eps * total_return_size" random samples and
* "(1 - eps) * total_return_size" best searched samples.
*/
std::vector<SearchState> SearchModuleExprEpsGreedy(
const TuningOptions& options);
#ifdef CINN_WITH_TEST
/**
* Method only be called during testing. It is used to set mock search
* space.
*
* @param search_space: the mock search space, note that EvolutionarySearch
* takes the ownership.
*/
void SetSearchSpace(SearchSpace* search_space) {
search_space_.reset(search_space);
}
// Method only be called during testing, it is a wrapper of private method
// InitSketch().
std::vector<SearchState> TestInitSketch(int num,
const std::string& strategy) {
return InitSketch(num, strategy);
}
// Method only be called during testing, it is a wrapper of private method
// Evolve().
std::vector<SearchState> TestEvolve(
const std::vector<SearchState>& population,
int cross_over_num,
int ret_num) {
return Evolve(population, cross_over_num, ret_num);
}
#endif
private:
std::vector<SearchState> GetTopKCandidatesFromDatabase(int topk);
/**
* \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate.
* @param strategy The strategy to generate sketches,
* Current optional strategies are "rule_prune" or "random_prune" or
* "random".
* - "rule_prune": will use rules to prune and generate sketches as
* efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate
* the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketches.
*/
std::vector<SearchState> InitSketch(int num, const std::string& strategy);
SearchState Mutate(const SearchState& state,
utils::LinearRandomEngine::StateType* rand_seed);
SearchState CrossOver(const SearchState& state1, const SearchState& state2);
std::vector<SearchState> Evolve(const std::vector<SearchState>& population,
int cross_over_num,
int ret_num);
std::vector<SearchState> PickNextGenerationEpsGreedy(
const std::vector<SearchState>& population,
const std::vector<SearchState>& random_init,
int num,
float eps_greedy);
private:
std::unique_ptr<SearchSpace> search_space_;
const TuneTask& tune_task_;
const ExprCostModel& cost_model_; // not owned
Database* database_; // not owned
// used to duplicate states with the same structural IR
std::unordered_set<SearchState, SearchStateHash, SearchStateEqual>
visited_candidates_;
// mutate rule names and their weights
std::vector<std::tuple<std::string, double>> mutators_;
// mutate rules, the key is the accumulate weight of each mutate rule
std::map<double, std::unique_ptr<MutateRule>> weighted_mutators_;
// schedule rules used after mutation
std::vector<std::unique_ptr<PostScheduleRule>> post_schedule_rules_;
utils::LinearRandomEngine::StateType rand_seed_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
namespace auto_schedule {
std::vector<TuneTask> CreateTasks(const frontend::Program& program,
const Target& target) {
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
TaskCreator task_creator;
auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks.size(); ++i) {
tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
task_registry->Regist(tasks[i].serialized_key,
ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
}
return tasks;
}
/**
* A mock search space is only used for test. It creates integer ir::Expr from
* 0, -1, -2, ... and set the cost value same as the integer value.
*
* So evolutionary search should be able to find the minimal ModuleExpr with
* smallest ir::Expr. This file tests it.
*/
class MockSearchSpace : public SearchSpace {
public:
explicit MockSearchSpace(const TuneTask& tune_task)
: SearchSpace(tune_task) {}
int GetMinExprValue() const { return min_expr_value_; }
int GetModuleExprSize() const { return module_expr_size_; }
std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy) override {
std::vector<SearchState> ret;
for (int i = 0; i < num; ++i) {
std::vector<ir::Expr> exprs;
for (int j = 0; j < module_expr_size_; ++j) {
exprs.push_back(ir::Expr(-i));
}
min_expr_value_ = -i;
ret.push_back(SearchState(ir::IRSchedule(ir::ModuleExpr(exprs))));
}
return ret;
}
private:
int module_expr_size_ = 10;
int min_expr_value_ = 0;
};
class MockCostModel : public ExprCostModel {
float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const override {
float cost = 0.0f;
std::vector<ir::Expr> exprs = sample.GetExprs();
for (const ir::Expr& expr : exprs) {
if (expr.as_int32()) {
cost += static_cast<float>((expr.as_int32()));
}
}
return cost;
}
};
TEST(EvolutionarySearch, GetOneBest) {
TuneTask mock_tune_task;
mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
MockCostModel cost_model;
TuningOptions options;
Database db(2);
EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);
MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
// Ownership is transferred so don't delete mock_search_space
evolutionary_search.SetSearchSpace(mock_search_space);
SearchState best_state = evolutionary_search.SearchModuleExpr(options);
std::vector<ir::Expr> exprs = best_state->ir_schedule.GetModule().GetExprs();
EXPECT_GE(exprs.size(), 1UL);
for (const ir::Expr& e : exprs) {
EXPECT_EQ(e.as_int32(), mock_search_space->GetMinExprValue());
}
}
TEST(EvolutionarySearch, GetEpsGreedy) {
TuneTask mock_tune_task;
mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
ExprCostModel cost_model;
TuningOptions options;
Database db(2);
EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);
MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
// Ownership is transferred so don't delete mock_search_space
evolutionary_search.SetSearchSpace(mock_search_space);
std::vector<SearchState> search_states =
evolutionary_search.SearchModuleExprEpsGreedy(options);
EXPECT_GE(search_states.size(), 1UL);
size_t expr_size =
static_cast<size_t>(mock_search_space->GetModuleExprSize());
for (const SearchState& state : search_states) {
EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size);
}
}
TEST(EvolutionarySearch, Evolve) {
auto target = common::DefaultNVGPUTarget();
auto tasks = CreateTasks(
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}),
target);
CHECK_EQ(tasks.size(), 1);
ExprCostModel cost_model;
std::vector<const ir::ModuleExpr*> cost_model_samples(1);
std::vector<float> cost_model_labels(1);
for (size_t i = 0; i < 2; ++i) {
ir::ModuleExpr me({ir::Expr(tasks[0].lowered_funcs[0])});
cost_model_samples[0] = &me;
cost_model_labels[0] = i + 10;
cost_model.Update(cost_model_samples, cost_model_labels, target);
}
Database db(2);
TuningOptions options;
options.evolution_pick_database_topk = 0;
EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db);
int num_population = 10;
std::vector<SearchState> init_sketch =
evolutionary_search.TestInitSketch(num_population, "rule_prune");
for (int i = 0; i < num_population; ++i) {
ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule());
cost_model_samples[0] = &me;
cost_model_labels[0] = i;
cost_model.Update(cost_model_samples, cost_model_labels, target);
}
VLOG(6) << "init sketch costs:";
for (auto s : init_sketch) {
VLOG(6) << "cost = " << s->predicted_cost;
}
std::vector<SearchState>*population_pre_ptr = &init_sketch,
*population_next_ptr;
std::vector<SearchState> population;
for (int i = 0; i < 10; ++i) {
population = evolutionary_search.TestEvolve(
*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10);
population_next_ptr = &population;
VLOG(6) << "population[" << i + 1 << "] costs:";
double total_cost_pre = 0.0, total_cost_next = 0.0;
for (auto s : *population_pre_ptr) {
total_cost_pre += s->predicted_cost;
}
for (auto s : *population_next_ptr) {
total_cost_next += s->predicted_cost;
VLOG(6) << "cost = " << s->predicted_cost;
}
VLOG(6) << "total_cost_next = " << total_cost_next;
CHECK_LE(total_cost_next, total_cost_pre);
std::swap(population_pre_ptr, population_next_ptr);
}
}
} // namespace auto_schedule
} // namespace cinn
core_gather_headers()
gather_srcs(cinnapi_src SRCS mutate_rule.cc mutate_tile_size.cc)
cinn_cc_test(test_mutate_tile_size SRCS mutate_tile_size_test.cc DEPS cinncore)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
namespace cinn {
namespace auto_schedule {
std::unique_ptr<MutateRule> MutateRule::Make(const std::string& name) {
if (name == "mutate_tile_size") {
return std::make_unique<MutateTileSize>();
} else {
LOG(FATAL) << "MutateRule " << name << " is not supported.";
}
return nullptr;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace auto_schedule {
/**
* Base class for rules of mutate,
* is used for mutating the trace(ScheduleDesc) to explore the search space.
*/
class MutateRule {
public:
MutateRule() = default;
/**
* @brief Apply the mutate rule to the given trace.
* @param trace The given trace for mutation.
* @param rand_seed The random seed for mutation.
* @return The mutated trace.
*/
virtual ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) = 0;
/**
* @brief Create a MutateRule with name.
* @param name The name of mutate rule, consisting of lowercase letters and
* underscores
* @return The created MutateRule.
*/
static std::unique_ptr<MutateRule> Make(const std::string& name);
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::ir::ScheduleDesc;
using ::cinn::utils::LinearRandomEngine;
using SampledTile = std::tuple<ScheduleDesc::Step, std::vector<int>, int>;
static std::vector<int> Factorize(int n) {
std::vector<int> res;
for (int i = 1; i * i <= n; ++i) {
if (n % i == 0) {
res.push_back(i);
if (i * i != n) {
res.push_back(n / i);
}
}
}
std::sort(res.begin(), res.end());
return res;
}
std::vector<SampledTile> FindSampledTiles(const ScheduleDesc& trace) {
std::vector<SampledTile> tiles;
int step_idx = 0;
for (auto&& step : trace.Steps()) {
if (step.type == "TagPostSchedule") {
break;
}
if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors =
absl::get<std::vector<int>>(step.attrs.at("decision"));
CHECK(tile_factors.size() >= 2)
<< "factors size must be greater equal than 2, which is "
<< tile_factors.size();
tiles.push_back(std::make_tuple(step, tile_factors, step_idx));
}
++step_idx;
}
return tiles;
}
ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
const SampledTile& tile,
LinearRandomEngine::StateType* rand_seed) {
ScheduleDesc::Step step = std::get<0>(tile);
std::vector<int> tile_factors = std::get<1>(tile);
int split_size = tile_factors.size();
// Step 1. Choose 2 loops with index: 'loop_x' and 'loop_y'
int loop_x, loop_y;
bool all_one_factors = true;
for (int t : tile_factors) {
if (t != 1) {
all_one_factors = false;
break;
}
}
if (all_one_factors) {
VLOG(6) << "Factors are all 1, unable to mutate, return the original trace";
return trace;
}
while (true) {
VLOG(6) << "while (true) loop in DoMutateTileSize";
loop_x = utils::SampleUniformInt(0, split_size, rand_seed);
if (tile_factors.at(loop_x) <= 1) {
continue;
}
loop_y = utils::SampleUniformInt(0, split_size - 1, rand_seed);
if (loop_y >= loop_x) {
++loop_y;
}
std::vector<int> optional_factors = Factorize(tile_factors.at(loop_x));
// Step 2. Choose the divisor for mutate.
int divisor;
if (loop_y == split_size - 1) {
int max_innermost_factor =
absl::get<int>(step.attrs.at("max_innermost_factor"));
int max_optional_factor_idx = optional_factors.size() - 1;
for (; max_optional_factor_idx > 0; --max_optional_factor_idx) {
if (optional_factors.at(max_optional_factor_idx) *
tile_factors.at(loop_y) <=
max_innermost_factor) {
break;
}
}
if (max_optional_factor_idx == 0) {
if (split_size <= 2) {
VLOG(6) << "Unable to mutate, return the original trace";
return trace;
}
continue;
}
divisor = optional_factors.at(
utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed));
} else {
divisor = optional_factors.at(
utils::SampleUniformInt(1, optional_factors.size(), rand_seed));
}
// Step 3. Determine the new tile value
VLOG(6) << "DoMutateTileSize: divisor = " << divisor
<< ", before mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< ", factors[" << loop_y << "] = " << tile_factors[loop_y];
tile_factors[loop_x] /= divisor;
tile_factors[loop_y] *= divisor;
VLOG(6) << "after mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< ", factors[" << loop_y << "] = " << tile_factors[loop_y];
// Step 4. Create a new step with new tile values and return the new trace
int step_idx = std::get<2>(tile);
return trace.ForkAndUpdate(step_idx, tile_factors, true);
}
}
ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace,
LinearRandomEngine::StateType* rand_seed) {
VLOG(6) << "Start applying MutateTileSize, old trace: \n"
<< trace.DebugString();
std::vector<ScheduleDesc::Step> sample_tile_steps;
std::vector<std::vector<int>> sample_tile_data;
auto sampled_tiles = FindSampledTiles(trace);
if (sampled_tiles.size() == 0) {
VLOG(6) << "MutateTileSize failed, try other mutate rules.";
return trace;
}
int sample_step_idx =
utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed);
auto new_trace =
DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed);
VLOG(6) << "End applying MutateTileSize, new trace: \n"
<< new_trace.DebugString();
return new_trace;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
namespace cinn {
namespace auto_schedule {
/**
* The rule to mutate tile size, witch will modify the factors of the Split
* primitive.
*/
class MutateTileSize : public MutateRule {
public:
MutateTileSize() = default;
ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) override;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
TEST(MutateTileSize, Basic) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
const int kSize = 32;
Expr M(kSize);
Expr N(kSize);
Expr K(kSize);
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute(
{M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMutateTileSize_Basic",
stages,
{A, B, C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Original Expr: ";
VLOG(6) << ast_expr;
ir::ModuleExpr module_expr({ast_expr});
// We need to fix the seed as a constant to ensure that the result can be
// repeated.
utils::LinearRandomEngine::StateType rand_seed = 123;
ir::IRSchedule ir_schedule(module_expr, rand_seed);
ir::IRSchedule new_ir_schedule(ir_schedule);
// apply schedule
auto loops = ir_schedule.GetLoops("C");
auto factors = ir_schedule.SamplePerfectTile(loops[0], 2, kSize);
auto splited = ir_schedule.Split(loops[0], factors);
// apply mutate
MutateTileSize mutator;
ir::ScheduleDesc sch_desc =
mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
sch_desc.Replay(&new_ir_schedule, true);
VLOG(6) << "Expr before mutate tile size: \n"
<< ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n"
<< new_ir_schedule.GetModule().GetExprs()[0];
std::string target_new_ir = R"ROC({
ScheduleBlock(root)
{
serial for (i_1, 0, 2)
{
serial for (i_2, 0, 16)
{
serial for (j, 0, 32)
{
ScheduleBlock(C__reduce_init)
{
i0, i1 = axis.bind(((16 * i_1) + i_2), j)
C__reduce_init[i0, i1] = 0.00000000f
}
serial for (reduce_axis_k, 0, 32)
{
ScheduleBlock(C)
{
i0_0, i1_0, i2 = axis.bind(((16 * i_1) + i_2), j, reduce_axis_k)
C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0]))
}
}
}
}
}
}
})ROC";
auto get_ir_str = [](const ir::IRSchedule* ir_sch) -> std::string {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
return ss.str();
};
ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir);
std::vector<int> last_tile_factors = {2, 16};
for (int i = 0; i < 10; ++i) {
sch_desc = mutator.Apply(sch_desc, &rand_seed);
for (auto&& step : sch_desc.Steps()) {
if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors =
absl::get<std::vector<int>>(step.attrs.at("decision"));
ASSERT_EQ(tile_factors.size(), last_tile_factors.size());
ASSERT_NE(tile_factors[0], last_tile_factors[0]);
ASSERT_NE(tile_factors[1], last_tile_factors[1]);
ASSERT_EQ(tile_factors[0] * tile_factors[1], kSize);
last_tile_factors = tile_factors;
}
}
}
}
} // namespace auto_schedule
} // namespace cinn
core_gather_headers()
gather_srcs(cinnapi_src SRCS task_creator.cc task_optimizer.cc tune_task.cc)
gather_srcs(cinnapi_src SRCS task_creator.cc task_optimizer.cc)
cinn_cc_test(test_task_creator SRCS task_creator_test.cc DEPS cinncore)
cinn_cc_test(test_tune_task SRCS tune_task_test.cc DEPS cinncore)
cinn_cc_test(test_task_registry SRCS task_registry_test.cc DEPS cinncore)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include <glog/logging.h>
#include <memory>
#include <tuple>
#include <vector>
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/pass.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::common::GraphEdge;
using ::cinn::common::GraphNode;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::Node;
using ::cinn::hlir::framework::NodeData;
std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
std::vector<TuneTask> ret_tasks;
const std::vector<std::shared_ptr<Graph::Group>>* groups =
&graph->fusion_groups;
std::vector<std::shared_ptr<Graph::Group>> non_fused_groups;
// The input graph doesn't run Op Fusion
if (graph->fusion_groups.empty()) {
hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
groups = &graph->fusion_groups;
}
VLOG(3) << "Graph groups size:" << groups->size();
for (const auto& sub_graph : *groups) {
ret_tasks.emplace_back(TuneTask());
ret_tasks.back().subgraph = sub_graph;
ret_tasks.back().target = graph->target_;
}
return ret_tasks;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/graph.h"
namespace cinn {
namespace auto_schedule {
/**
* Class to create auto tune task.
*/
class TaskCreator {
public:
std::vector<TuneTask> CreateTuneTaskOpLevel(hlir::framework::Graph* graph);
};
} // namespace auto_schedule
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment