Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include <glog/logging.h>
#include <atomic>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// Predicts the cost of `sample` when scheduled for `target`.
//
// While no training round has been recorded (trained_times_ == 0) the
// sentinel SearchState::NOT_INIT_COST is returned. Otherwise the sample is
// converted to its fixed-size feature vector and the first value returned by
// XgbCostModel::Predict for that single-row batch is the result.
float ExprCostModel::Predict(const ir::ModuleExpr& sample,
                             const common::Target& target) const {
  const bool never_trained = (trained_times_.load() == 0);
  if (never_trained) {
    return SearchState::NOT_INIT_COST;
  }

  FeatureExtractor feature_extractor;
  Feature sample_feature = feature_extractor.Extract(sample, target);
  std::vector<float> sample_row = sample_feature.ToFixedSizeVector();

  // The base model works on batches, so wrap the single row.
  std::vector<float> predictions = XgbCostModel::Predict({sample_row});
  return predictions[0];
}
// Trains the underlying XgbCostModel from scratch on `samples`.
//
// @param samples  Non-null module expressions; one feature row is extracted
//                 from each of them for `target`.
// @param labels   One label per sample; must have the same size as `samples`.
// @param target   Target passed to the feature extractor.
//
// Size mismatches and null samples are fatal (glog CHECK).
void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
                          const std::vector<float>& labels,
                          const common::Target& target) {
  const size_t total_size = samples.size();
  CHECK_EQ(total_size, labels.size())
      << "Samples must have same size as labels";

  std::vector<std::vector<float>> train_feature_numbers(total_size);
  FeatureExtractor extractor;
  for (size_t i = 0; i < total_size; ++i) {
    CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
    Feature feature = extractor.Extract(*samples[i], target);
    train_feature_numbers[i] = feature.ToFixedSizeVector();
  }
  XgbCostModel::Train(train_feature_numbers, labels);

  // Record the training round only after the base model has been trained:
  // Predict() uses a non-zero trained_times_ as the signal that a model is
  // available, so publishing it earlier would let Predict() query a model
  // that does not exist yet (or never will, if training throws).
  trained_times_.store(1);
}
// Feeds additional samples to the underlying XgbCostModel via
// XgbCostModel::Update.
//
// @param samples  Non-null module expressions; one feature row is extracted
//                 from each of them for `target`.
// @param labels   One label per sample; must have the same size as `samples`.
// @param target   Target passed to the feature extractor.
//
// Size mismatches and null samples are fatal (glog CHECK).
void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
                           const std::vector<float>& labels,
                           const common::Target& target) {
  const size_t total_size = samples.size();
  CHECK_EQ(total_size, labels.size())
      << "Samples must have same size as labels";

  std::vector<std::vector<float>> train_feature_numbers(total_size);
  FeatureExtractor extractor;
  for (size_t i = 0; i < total_size; ++i) {
    CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
    Feature feature = extractor.Extract(*samples[i], target);
    train_feature_numbers[i] = feature.ToFixedSizeVector();
  }
  XgbCostModel::Update(train_feature_numbers, labels);

  // Count the round only once the base model has actually been updated, so
  // Predict() never sees a non-zero trained_times_ for a model that is still
  // being built or whose update failed.
  ++trained_times_;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
* A C++ cost model which trains and predicts on ir::ModuleExpr.
*
* Each sample is converted to a fixed-size feature vector by
* FeatureExtractor / Feature::ToFixedSizeVector and then handed to the
* XgbCostModel base class.
*/
class ExprCostModel : public XgbCostModel {
public:
// Returns the predicted cost of `sample` on `target`, or
// SearchState::NOT_INIT_COST while trained_times_ is still 0.
virtual float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const;
// Trains the base model on `samples`/`labels` (same size, no null samples)
// and sets trained_times_ to 1.
void Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target);
// Passes additional `samples`/`labels` to XgbCostModel::Update and
// increments trained_times_.
void Update(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target);
private:
// Number of recorded training rounds; 0 means Predict has no model to query.
std::atomic<int> trained_times_{0};
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include <glog/logging.h>
#include <vector>
#include "paddle/cinn/common/target.h"
namespace cinn {
namespace auto_schedule {
// Builds a Feature for the unknown target (common::UnkTarget()).
// Delegates to the target-taking constructor so the root-block setup lives in
// exactly one place.
Feature::Feature() : Feature(common::UnkTarget()) {}
// Builds a Feature for `target` containing only the root block.
// The initializer list follows the member declaration order
// (stack_encoded_feature_, current_loop_block_index_, parent_indices_,
// target_), which is the order in which members are actually initialized.
Feature::Feature(const common::Target& target)
    : stack_encoded_feature_(1),  // initialize a LoopBlockFeature as root block
      current_loop_block_index_(0),
      parent_indices_(1, -1),  // the root block has no parent
      target_(target) {}
// Flattens the variable-length list of loop-block features into a vector of
// LoopBlockFeature::kTotalSize + 1 floats.
//
// Layout of the result (before scaling):
//   [0]      1 if the target equals common::DefaultNVGPUTarget(), else 0
//   [1..]    arithmetic, memory and reduce/broadcast counters, each summed
//            over all blocks and weighted by the product of loop lengths from
//            the root down to (and including) the block
//   then     kOptApplySize slots counting blocks per ForOptimizeFeatureEnum
//   then     thread / vectorize lengths, weighted by the product of the loop
//            lengths of the block's ancestors only
// Finally every entry is passed through slog().
//
// Products are accumulated in 64 bits: nested loop extents multiplied by
// operation counts easily exceed the range of int, and signed int overflow
// is undefined behavior.
std::vector<float> Feature::ToFixedSizeVector() {
  std::vector<float> ret(LoopBlockFeature::kTotalSize + 1,
                         0);  // LoopBlockFeature::kTotalSize plus 1 for target
  if (target_ == common::DefaultNVGPUTarget()) {
    ret[0] = 1;
  }  // else 0 for other cases

  // iter_multi_num[i] is the loop-length product from the root to block i;
  // the counters of block i are multiplied by it.
  std::vector<long long> iter_multi_num;
  iter_multi_num.reserve(stack_encoded_feature_.size());

  for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) {
    const LoopBlockFeature& loop_feature = stack_encoded_feature_[i];
    long long loop_prod = 1;
    long long parent_prod = 1;
    if (i != 0) {
      // Parents are always pushed before their children, so the parent's
      // product is already available.
      parent_prod = iter_multi_num[parent_indices_[i]];
      loop_prod = parent_prod * loop_feature.loop_length;
    }
    iter_multi_num.push_back(loop_prod);

    int j = 1;  // slot 0 is reserved for the target
    // Adds count * scale to the current slot and advances to the next one.
    const auto add_scaled = [&ret, &j](int count, long long scale) {
      ret[j] += static_cast<float>(count * scale);
      ++j;
    };

    // Arithmetic features
    add_scaled(loop_feature.float_add_or_sub, loop_prod);
    add_scaled(loop_feature.float_mul, loop_prod);
    add_scaled(loop_feature.float_div_or_mod, loop_prod);
    add_scaled(loop_feature.float_cmp, loop_prod);
    add_scaled(loop_feature.float_math_func, loop_prod);
    add_scaled(loop_feature.float_other_call, loop_prod);
    add_scaled(loop_feature.int_add_or_sub, loop_prod);
    add_scaled(loop_feature.int_mul, loop_prod);
    add_scaled(loop_feature.int_div_or_mod, loop_prod);
    add_scaled(loop_feature.int_cmp, loop_prod);
    add_scaled(loop_feature.int_math_func, loop_prod);
    add_scaled(loop_feature.int_other_call, loop_prod);
    add_scaled(loop_feature.bool_op, loop_prod);
    add_scaled(loop_feature.select_op, loop_prod);

    // Memory features
    add_scaled(loop_feature.mem_alloc, loop_prod);
    add_scaled(loop_feature.mem_free, loop_prod);
    add_scaled(loop_feature.mem_read, loop_prod);
    add_scaled(loop_feature.mem_write, loop_prod);

    // Reduce and broadcast features
    add_scaled(loop_feature.float_reduce_sum_or_sub, loop_prod);
    add_scaled(loop_feature.float_reduce_mul, loop_prod);
    add_scaled(loop_feature.float_reduce_div, loop_prod);
    add_scaled(loop_feature.float_reduce_max_or_min, loop_prod);
    add_scaled(loop_feature.float_broadcast, loop_prod);
    add_scaled(loop_feature.int_reduce_sum_or_sub, loop_prod);
    add_scaled(loop_feature.int_reduce_mul, loop_prod);
    add_scaled(loop_feature.int_reduce_div, loop_prod);
    add_scaled(loop_feature.int_reduce_max_or_min, loop_prod);
    add_scaled(loop_feature.int_broadcast, loop_prod);

    // Loop optimization kind: count one block in the slot of its kind.
    ret[j + static_cast<int>(loop_feature.loop_opt_type)] += 1;
    j += LoopBlockFeature::kOptApplySize;

    // Thread features are weighted by the ancestors' product only.
    add_scaled(loop_feature.len_blockIdx_x, parent_prod);
    add_scaled(loop_feature.len_blockIdx_y, parent_prod);
    add_scaled(loop_feature.len_blockIdx_z, parent_prod);
    add_scaled(loop_feature.len_threadIdx_x, parent_prod);
    add_scaled(loop_feature.len_threadIdx_y, parent_prod);
    add_scaled(loop_feature.len_threadIdx_z, parent_prod);
    add_scaled(loop_feature.len_vthread, parent_prod);
    add_scaled(loop_feature.vectorize_factor, parent_prod);
  }

  // Compress the dynamic range of every entry (including the target flag).
  for (size_t i = 0; i < ret.size(); ++i) {
    ret[i] = slog(ret[i]);
  }
  return ret;
}
// Opens a new loop block nested in the current one and makes it current.
// The new block is appended to the stack, the enclosing block's
// num_sub_loops is incremented and the enclosing block is recorded as the
// new block's parent.
void Feature::IntoLoopBlock() {
  const int enclosing_index = current_loop_block_index_;

  stack_encoded_feature_.emplace_back();
  // Index (rather than hold a reference) because emplace_back may reallocate.
  ++stack_encoded_feature_[enclosing_index].num_sub_loops;
  parent_indices_.push_back(enclosing_index);

  current_loop_block_index_ =
      static_cast<int>(stack_encoded_feature_.size() - 1);
}
// Leaves the current loop block: its recorded parent becomes current again.
// For the root block the recorded parent is -1.
void Feature::ExitLoopBlock() {
  const int enclosing_index = parent_indices_[current_loop_block_index_];
  current_loop_block_index_ = enclosing_index;
}
// Mutable access to the block that feature counts are currently collected on.
LoopBlockFeature& Feature::CurrentLoopBlock() {
  LoopBlockFeature& active_block =
      stack_encoded_feature_[current_loop_block_index_];
  return active_block;
}
// Read-only access to the block that feature counts are currently collected
// on.
const LoopBlockFeature& Feature::CurrentLoopBlock() const {
  const LoopBlockFeature& active_block =
      stack_encoded_feature_[current_loop_block_index_];
  return active_block;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/* Kind of optimization applied to a loop.
 * The numeric value is used as an offset into the flattened feature vector,
 * so the values are spelled out explicitly. */
enum class ForOptimizeFeatureEnum : int {
  kNone = 0,
  kGpuBind = 1,
  kParallel = 2,
  kUnroll = 3,
  kVectorize = 4
};
/* Scales a feature number logarithmically: log2(|x| + 1).
 * The sign of the input is discarded, so x and -x map to the same value. */
inline float slog(float x) {
  const float magnitude = x < 0 ? -x : x;
  return std::log2(magnitude + 1);
}
// Per-loop-block operation counters collected by the feature extractor.
// The declaration order of the counters below matches the order in which
// Feature::ToFixedSizeVector writes them into the flat feature vector.
class LoopBlockFeature {
public:
// TODO(zhhsplendid): distinguish more types such as float16, float32,
// float64, etc. However, the speed gap between float and int is larger than
// the gap between different bit widths, so only int and float are
// distinguished here.
/* Arithmetic features */
int float_add_or_sub = 0;
int float_mul = 0;
int float_div_or_mod = 0;
int float_cmp = 0;
int float_math_func = 0;
int float_other_call = 0; // like simple assign, cast, etc.
int int_add_or_sub = 0;
int int_mul = 0;
int int_div_or_mod = 0;
int int_cmp = 0;
int int_math_func = 0;
int int_other_call = 0; // like simple assign, cast, etc.
int bool_op = 0;
int select_op = 0;
// 6 float counters + 6 int counters + bool_op + select_op
static constexpr int kArithSize = 6 * 2 + 2;
/**
* Buffer memory features, which are the numbers of memory operations.
* Note that memory operations of different sizes can have different speeds,
* however the speed difference would be small in OS. A more meticulous TODO
* may be to collect operand sizes (like alloc size, write size, and so on).
*/
int mem_alloc = 0;
int mem_free = 0;
int mem_read = 0;
int mem_write = 0;
static constexpr int kMemSize = 4;
/**
* Reduce and Broadcast features
*/
int float_reduce_sum_or_sub = 0;
int float_reduce_mul = 0;
int float_reduce_div = 0;
int float_reduce_max_or_min = 0;
int float_broadcast = 0;
int int_reduce_sum_or_sub = 0;
int int_reduce_mul = 0;
int int_reduce_div = 0;
int int_reduce_max_or_min = 0;
int int_broadcast = 0;
static constexpr int kReduceBroadcastSize = 10;
/* Loop type features */
// A TODO: maybe add a loop position (Inner, Outer, Middle) feature
ForOptimizeFeatureEnum loop_opt_type = ForOptimizeFeatureEnum::kNone;
// One slot per ForOptimizeFeatureEnum value in the flat feature vector
static constexpr int kOptApplySize = 5;
/* Thread features if loop is optimized by GPU or CPU parallelism.
* Useless in other cases.
*/
int len_blockIdx_x = 0;
int len_blockIdx_y = 0;
int len_blockIdx_z = 0;
int len_threadIdx_x = 0;
int len_threadIdx_y = 0;
int len_threadIdx_z = 0;
int len_vthread = 0; // length of virtual thread
int vectorize_factor = 0;
static constexpr int kThreadFeatureSize = 8;
// Total number of feature slots contributed by one LoopBlockFeature
static constexpr int kTotalSize = kArithSize + kMemSize +
kReduceBroadcastSize + kOptApplySize +
kThreadFeatureSize;
/* Non-feature attributes, maintained during feature extraction */
// Number of loop blocks directly nested inside the current one
int num_sub_loops = 0;
// Number of repeats of this loop, -1 represents unknown (defaults to 1)
int loop_length = 1;
};
/**
* Feature of Expr. It is used in CostModel.
*
* Holds one LoopBlockFeature per loop block (plus one for the root block) and
* can flatten them into a fixed-size float vector.
*/
class Feature {
public:
// Creates a Feature for common::UnkTarget() containing only the root block
Feature();
// Creates a Feature for `target` containing only the root block
explicit Feature(const common::Target& target);
// Convert the various-length loop block features to fixed-size vector
// (LoopBlockFeature::kTotalSize + 1 entries, the first one for the target)
std::vector<float> ToFixedSizeVector();
// Call when visit into a loop block to collect LoopBlockFeature
void IntoLoopBlock();
// Call when exit a loop block to collect LoopBlockFeature
void ExitLoopBlock();
// The current loop block which we should collect feature on
LoopBlockFeature& CurrentLoopBlock();
// The current loop block which we should collect feature on
const LoopBlockFeature& CurrentLoopBlock() const;
private:
// We treat a computation feature to be encoded as variable-length vector.
// The root compute block is not a loop, but we treat it as a size-1 loop.
// Blocks are encoded like a stack. Each LoopBlockFeature contains a
// num_sub_loops to indicate the next level sub-loop-block it contains.
//
// For example, code like:
//
// some_compute_0
// loop1 {
// some_compute_1
// loop2 {
// some_compute_2
// }
// }
//
// loop3 {
// some_compute_3
// }
//
// We go through the code and push loops into stack, then the features are
// encoded as [loop_block_feature_0, loop_block_feature_1,
// loop_block_feature_2, loop_block_feature_3] where loop_block_feature_i
// stores the features of some_compute_i (such as number of arithmetic
// operations)
//
// loop_block_feature_0.num_sub_loops = 2
// loop_block_feature_1.num_sub_loops = 1
// loop_block_feature_2.num_sub_loops = 0
// loop_block_feature_3.num_sub_loops = 0
std::vector<LoopBlockFeature> stack_encoded_feature_;
// Index into stack_encoded_feature_ of the block currently being collected
int current_loop_block_index_;
// parent_indices_[i] is the index of the block enclosing block i
// (-1 for the root block)
std::vector<int> parent_indices_;
// Target the feature was extracted for; encoded in slot 0 of the flat vector
common::Target target_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include <vector>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
namespace cinn {
namespace auto_schedule {
using namespace ::cinn::ir; // NOLINT
// Nothing to set up: feature_ is default-constructed.
FeatureExtractor::FeatureExtractor() = default;
// Entry point for visiting a generic expression: forwards to the base
// visitor's Visit(const Expr*).
// NOTE(review): presumably the base class dispatches on the node type to the
// typed Visit overloads defined in this file - confirm against ir_visitor.h.
void FeatureExtractor::Visit(const Expr *x) {
IRVisitorRequireReImpl::Visit(x);
}
// Extracts the Feature of `mod_expr` for `target`.
// The internal feature_ is reset to a fresh Feature(target), every expression
// of the module is visited to accumulate counters, and a copy of the result
// is returned.
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
                                  const common::Target &target) {
  feature_ = Feature(target);

  const auto &module_exprs = mod_expr.GetExprs();
  for (auto it = module_exprs.begin(); it != module_exprs.end(); ++it) {
    const ir::Expr &current_expr = *it;
    Visit(&current_expr);
  }
  return feature_;
}
// Defines FeatureExtractor::Visit(const NodeType*) for node types that add no
// feature count of their own: the visitor only recurses into every defined
// sub-expression returned by expr_fields().
#define VisitDoNothing(NodeType) \
void FeatureExtractor::Visit(const NodeType *x) { \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \
for (const Expr *e : sub_exprs) { \
if (e->defined()) { \
Visit(e); \
} \
} \
}
// Node types that are only traversed, never counted.
VisitDoNothing(IntImm);
VisitDoNothing(UIntImm);
VisitDoNothing(FloatImm);
VisitDoNothing(StringImm);
VisitDoNothing(Block);
VisitDoNothing(_Module_);
VisitDoNothing(_Var_);
VisitDoNothing(_LoweredFunc_);
VisitDoNothing(ScheduleBlock);
VisitDoNothing(ScheduleBlockRealize);
VisitDoNothing(Ramp);
VisitDoNothing(_Buffer_);
VisitDoNothing(_BufferRange_);
// Defines an empty FeatureExtractor::Visit(const NodeType*): the node is
// neither counted nor are its sub-expressions visited.
#define NotVisitExprFields(NodeType) \
void FeatureExtractor::Visit(const NodeType *x) {}
// Tensors are skipped entirely.
NotVisitExprFields(_Tensor_)
// Defines FeatureExtractor::Visit(const NodeType*) for nodes counted per data
// type: if the node's type equals F32, F16 or F64 the current block's
// float_<member> counter is increased by the type's lanes(), otherwise
// int_<member> is increased by lanes(). Afterwards every defined
// sub-expression is visited.
// NOTE(review): any type that is not exactly F32/F16/F64 (as decided by the
// type's operator==) ends up in the int_ counter - confirm that this is
// intended for vectorized float types.
#define VisitForDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || \
x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \
} else { \
feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \
} \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \
for (const Expr *e : sub_exprs) { \
if (e->defined()) { \
Visit(e); \
} \
} \
}
// Additive operations
VisitForDtypePattern(Add, add_or_sub);
VisitForDtypePattern(Sub, add_or_sub);
VisitForDtypePattern(Minus, add_or_sub);
// Multiplicative operations
VisitForDtypePattern(Mul, mul);
VisitForDtypePattern(Div, div_or_mod);
VisitForDtypePattern(Mod, div_or_mod);
VisitForDtypePattern(FracOp, div_or_mod);
// Comparisons
VisitForDtypePattern(EQ, cmp);
VisitForDtypePattern(NE, cmp);
VisitForDtypePattern(GT, cmp);
VisitForDtypePattern(GE, cmp);
VisitForDtypePattern(LT, cmp);
VisitForDtypePattern(LE, cmp);
// Calls are counted as math functions
VisitForDtypePattern(Call, math_func);
VisitForDtypePattern(PrimitiveNode, math_func);
// Everything else that produces a value (cast, let)
VisitForDtypePattern(Cast, other_call);
VisitForDtypePattern(Let, other_call);
// Defines FeatureExtractor::Visit(const NodeType*) for n-ary nodes: a node
// with k operands adds k - 1 to float_<member> (type equals F32/F16/F64) or
// int_<member> (any other type), then every defined sub-expression is
// visited.
// NOTE(review): operands().size() - 1 is computed on the container's size
// type, so this assumes the node always has at least one operand - confirm.
#define VisitForMultiOperandsDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || \
x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += \
(x->operands().size() - 1); \
} else { \
feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \
} \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \
for (const Expr *e : sub_exprs) { \
if (e->defined()) { \
Visit(e); \
} \
} \
}
VisitForMultiOperandsDtypePattern(Sum, add_or_sub);
VisitForMultiOperandsDtypePattern(Product, mul);
// Defines FeatureExtractor::Visit(const NodeType*) for nodes that are counted
// independently of their data type: the current block's <member> counter is
// increased by one, then every defined sub-expression is visited.
#define VisitCountMemberPattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \
feature_.CurrentLoopBlock().member += 1; \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \
for (const Expr *e : sub_exprs) { \
if (e->defined()) { \
Visit(e); \
} \
} \
}
// Boolean logic
VisitCountMemberPattern(And, bool_op);
VisitCountMemberPattern(Or, bool_op);
VisitCountMemberPattern(Not, bool_op);
// Value selection
VisitCountMemberPattern(Max, select_op);
VisitCountMemberPattern(Min, select_op);
VisitCountMemberPattern(IfThenElse, select_op);
VisitCountMemberPattern(Select, select_op);
// Memory operations
VisitCountMemberPattern(Alloc, mem_alloc);
VisitCountMemberPattern(Free, mem_free);
VisitCountMemberPattern(Load, mem_read);
VisitCountMemberPattern(Store, mem_write);
/* Visit for loops */
void FeatureExtractor::Visit(const For *x) {
feature_.IntoLoopBlock();
LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();
if (x->min.is_constant() && x->extent.is_constant()) {
loop_feature.loop_length =
(x->extent.get_constant() - x->min.get_constant());
} else {
loop_feature.loop_length = -1; // -1 represents unknown
}
if (x->is_parallel()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel;
loop_feature.len_vthread = loop_feature.loop_length;
} else if (x->is_unrolled()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll;
} else if (x->is_vectorized()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize;
loop_feature.vectorize_factor = x->vectorize_info().factor;
} else if (x->is_binded()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind;
const BindInfo &bind_info = x->bind_info();
int offset = bind_info.offset;
if (bind_info.for_type == ForType::GPUBlock) {
if (offset == 0) {
loop_feature.len_blockIdx_x = loop_feature.loop_length;
} else if (offset == 1) {
loop_feature.len_blockIdx_y = loop_feature.loop_length;
} else if (offset == 2) {
loop_feature.len_blockIdx_z = loop_feature.loop_length;
}
} else if (bind_info.for_type == ForType::GPUThread) {
if (offset == 0) {
loop_feature.len_threadIdx_x = loop_feature.loop_length;
} else if (offset == 1) {
loop_feature.len_threadIdx_y = loop_feature.loop_length;
} else if (offset == 2) {
loop_feature.len_threadIdx_z = loop_feature.loop_length;
}
}
}
std::vector<const Expr *> sub_exprs = x->expr_fields();
for (const Expr *e : sub_exprs) {
Visit(e);
}
feature_.ExitLoopBlock();
}
void FeatureExtractor::Visit(const PolyFor *x) {
Expr copy = optim::IRCopy(Expr(x));
feature_.IntoLoopBlock();
optim::TransformPolyForToFor(&copy);
ir::For *loop = copy.As<For>();
CHECK(loop != nullptr);
Visit(loop);
feature_.ExitLoopBlock();
}
/* Visit for Reduce and Broadcast */
void FeatureExtractor::Visit(const Reduce *x) {
if (x->type() == common::F32() || x->type() == common::F16() ||
x->type() == common::F64()) {
switch (x->reduce_type) {
case Reduce::ReduceType::kSum:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break;
case Reduce::ReduceType::kSub:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break;
case Reduce::ReduceType::kDiv:
feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes();
break;
case Reduce::ReduceType::kMul:
feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes();
break;
case Reduce::ReduceType::kMax:
feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break;
case Reduce::ReduceType::kMin:
feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break;
}
} else {
switch (x->reduce_type) {
case Reduce::ReduceType::kSum:
feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes();
break;
case Reduce::ReduceType::kSub:
feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes();
break;
case Reduce::ReduceType::kDiv:
feature_.CurrentLoopBlock().int_reduce_div += x->type().lanes();
break;
case Reduce::ReduceType::kMul:
feature_.CurrentLoopBlock().int_reduce_mul += x->type().lanes();
break;
case Reduce::ReduceType::kMax:
feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes();
break;
case Reduce::ReduceType::kMin:
feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes();
break;
}
}
std::vector<const Expr *> sub_exprs = x->expr_fields();
for (const Expr *e : sub_exprs) {
Visit(e);
}
}
// Broadcast nodes are counted per data type in float_broadcast /
// int_broadcast.
VisitForDtypePattern(Broadcast, broadcast);
/* Visit for IntrinsicOp */
// Dispatches on the intrinsic kind: for every kind listed by
// INTRINSIC_KIND_FOR_EACH the node is cast to the matching intrinsics:: type
// and forwarded to that type's Visit overload.
// NOTE(review): llvm::dyn_cast may yield nullptr if the kind and the dynamic
// type disagree - presumably that cannot happen here; confirm.
void FeatureExtractor::Visit(const IntrinsicOp *x) {
switch (x->getKind()) {
#define __(op__) \
case IntrinsicKind::k##op__: \
Visit(llvm::dyn_cast<intrinsics::op__>(x)); \
break;
INTRINSIC_KIND_FOR_EACH(__)
#undef __
}
}
// Intrinsics that are only traversed, never counted.
VisitDoNothing(intrinsics::BufferGetDataHandle);
VisitDoNothing(intrinsics::BufferGetDataConstHandle);
VisitDoNothing(intrinsics::PodValueToX);
VisitDoNothing(intrinsics::BufferCreate);
VisitDoNothing(intrinsics::GetAddr);
VisitDoNothing(intrinsics::ArgsConstruct);
// Builtin intrinsics are counted per data type in float_other_call /
// int_other_call.
VisitForDtypePattern(intrinsics::BuiltinIntrin, other_call)
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace auto_schedule {
// IR visitor that walks the expressions of an ir::ModuleExpr and accumulates
// operation counts into a Feature.
class FeatureExtractor : public ir::IRVisitorRequireReImpl<void> {
public:
FeatureExtractor();
// Resets the internal Feature to Feature(target), visits every expression of
// `mod_expr` and returns the collected Feature.
Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target);
// Generic entry point; forwards to the base visitor.
void Visit(const Expr* x) override;
// One overriding Visit overload per IR node type listed by NODETY_FORALL.
#define __(op__) void Visit(const ir::op__* x) override;
NODETY_FORALL(__)
#undef __
// One additional Visit overload per intrinsic kind listed by
// INTRINSIC_KIND_FOR_EACH.
#define __(op__) virtual void Visit(const ir::intrinsics::op__* x);
INTRINSIC_KIND_FOR_EACH(__)
#undef __
private:
// Feature being accumulated by the current Extract call.
Feature feature_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <cmath>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace auto_schedule {
// Extracts features from a 32x32 element-wise copy B(i, j) = A(i, j) and
// checks the flattened feature vector.
TEST(FeatureExtractor, SimpleAssign) {
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
ir::Expr M(32);
ir::Expr N(32);
lang::Placeholder<float> A("A", {M, N});
ir::Tensor B = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr to test: " << ast_expr;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
FeatureExtractor extractor;
Feature feature = extractor.Extract(mod_expr, target);
std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
VLOG(6) << "Feature data before slog:";
// Every slot except target (0), mem_read (17), mem_write (18) and the
// kNone loop count (29) must be zero.
for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
if (i != 0 && i != 17 && i != 18 && i != 29) {
ASSERT_EQ(to_check[i], 0);
}
}
// slot 0: target flag (1 only for the default NVGPU target)
#ifdef CINN_WITH_CUDA
ASSERT_EQ(to_check[0], 1);
#else
ASSERT_EQ(to_check[0], 0);
#endif
// slot 17: mem_read, one read per element
ASSERT_EQ(to_check[17],
slog(M.get_constant() * N.get_constant()));
// slot 18: mem_write, one write per element
ASSERT_EQ(to_check[18],
slog(M.get_constant() * N.get_constant()));
// slot 29: number of non-optimized (kNone) loop blocks, including root block
ASSERT_EQ(to_check[29], slog(3));
}
// Extracts features from a 2x2x4 matrix multiply whose last loop of the first
// block is bound to threadIdx.x and checks the flattened feature vector.
TEST(FeatureExtractor, MatrixMultiply) {
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
ir::Expr M(2);
ir::Expr N(2);
ir::Expr K(4);
lang::Placeholder<float> A("A", {M, K});
lang::Placeholder<float> B("B", {K, N});
ir::Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = lang::Compute(
{M, N},
[&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
std::vector<Expr> vec_ast{funcs[0]->body};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
std::vector<ir::Expr> blocks = ir_sch.GetAllBlocks();
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir_sch.Bind(loops.back(), "threadIdx.x");
ir::Expr ast_expr = mod_expr.GetExprs()[0];
VLOG(6) << "Expr to test: " << ast_expr;
FeatureExtractor extractor;
Feature feature = extractor.Extract(mod_expr, target);
std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
// Slots that are allowed to be non-zero; every other slot must be zero.
std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37};
for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
if (!non_zero_indice.count(i)) {
ASSERT_EQ(to_check[i], 0);
}
}
// slot 0: target flag (1 only for the default NVGPU target)
#ifdef CINN_WITH_CUDA
ASSERT_EQ(to_check[0], 1);
#else
ASSERT_EQ(to_check[0], 0);
#endif
float out_loop = M.get_constant() * N.get_constant();
float total_loop = out_loop * K.get_constant();
// slot 1: float_add_or_sub
ASSERT_EQ(to_check[1], slog(total_loop));
// slot 2: float_mul
ASSERT_EQ(to_check[2], slog(total_loop));
// slot 17: mem_read
ASSERT_EQ(to_check[17], slog(total_loop * 3));
// slot 18: mem_write
ASSERT_EQ(to_check[18], slog(total_loop + out_loop));
// slot 29: number of non-optimized (kNone) loop blocks, including root block
ASSERT_EQ(to_check[29], slog(3));
// slot 30: number of kGpuBind loop blocks
ASSERT_EQ(to_check[30], slog(1));
// slot 37: len_threadIdx_x of the bound loop
ASSERT_EQ(to_check[37], slog(out_loop));
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include <gtest/gtest.h>
#include <pybind11/embed.h>
namespace cinn {
namespace auto_schedule {
// Basic checks of Feature without going through the feature extractor:
// the flattened vector of a root-only Feature, and the loop-length weighting
// and block bookkeeping after one nested loop block has been recorded.
TEST(Feature, Basic) {
  // Slot layout of the flattened vector: [0] target, then the arithmetic,
  // memory and reduce/broadcast counters, then one slot per loop-opt kind.
  const size_t total_slots =
      static_cast<size_t>(LoopBlockFeature::kTotalSize + 1);
  const size_t mem_read_slot =
      static_cast<size_t>(1 + LoopBlockFeature::kArithSize + 2);
  const size_t none_opt_slot = static_cast<size_t>(
      1 + LoopBlockFeature::kArithSize + LoopBlockFeature::kMemSize +
      LoopBlockFeature::kReduceBroadcastSize);

  // A fresh Feature only contains the root block, which counts as one
  // non-optimized (kNone) block; every other slot is zero.
  Feature feature;
  std::vector<float> root_only = feature.ToFixedSizeVector();
  ASSERT_EQ(root_only.size(), total_slots);
  for (size_t i = 0; i < root_only.size(); ++i) {
    if (i == none_opt_slot) {
      ASSERT_FLOAT_EQ(root_only[i], slog(1));
    } else {
      ASSERT_FLOAT_EQ(root_only[i], 0.0f);
    }
  }

  // Record one nested loop of length 4 that performs 2 memory reads.
  feature.IntoLoopBlock();
  feature.CurrentLoopBlock().loop_length = 4;
  feature.CurrentLoopBlock().mem_read = 2;
  feature.ExitLoopBlock();

  // Back at the root block, which now has one sub loop.
  ASSERT_EQ(feature.CurrentLoopBlock().num_sub_loops, 1);

  std::vector<float> nested = feature.ToFixedSizeVector();
  ASSERT_EQ(nested.size(), total_slots);
  // Counters are weighted by the loop length: 2 reads * 4 iterations.
  ASSERT_FLOAT_EQ(nested[mem_read_slot], slog(8));
  // Root block + nested block are both non-optimized.
  ASSERT_FLOAT_EQ(nested[none_opt_slot], slog(2));
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include <dirent.h>
#include <glog/logging.h>
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <atomic>
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <regex>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/common/python_interpreter_guard.h"
namespace cinn {
namespace auto_schedule {
// Definition of the static instance counter, starting at zero. The
// XgbCostModel constructor increments it and uses the previous value to run
// the one-time sys.path setup only for the first constructed instance.
std::atomic<int> XgbCostModel::xgb_cost_model_count_(0);
// Wrap a flat std::vector into a numpy array: the vector is first cast to a
// Python object and that object is then converted to a pybind11::array.
template <typename Dtype>
pybind11::array VectorToNumpy(const std::vector<Dtype>& vec) {
  pybind11::object py_values = pybind11::cast(vec);
  return pybind11::array(std::move(py_values));
}
// Convert a 2D vector (a list of equally sized rows) to a 2D numpy array.
//
// An empty outer vector yields an array of shape (0, 0). Otherwise the shape
// is (vec.size(), vec[0].size()) and each row is copied contiguously.
//
// Every row must have the same length as the first one. This is enforced with
// CHECK_EQ rather than assert() so that the check is also active in release
// builds: a shorter row would otherwise make the copy read past its end.
template <typename Dtype>
pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
  if (vec.empty()) {
    return pybind11::array(pybind11::dtype::of<Dtype>(), {0, 0});
  }
  const size_t num_rows = vec.size();
  const size_t num_cols = vec[0].size();
  std::vector<size_t> shape{num_rows, num_cols};
  pybind11::array ret(pybind11::dtype::of<Dtype>(), shape);
  Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());
  for (size_t i = 0; i < num_rows; ++i) {
    CHECK_EQ(vec[i].size(), num_cols)
        << "Sub vectors must have same size in VectorToNumpy";
    // Skip the copy for zero-width rows: data() may be null for an empty
    // vector and passing a null pointer to memcpy is undefined behavior.
    if (num_cols > 0) {
      std::memcpy(
          py_data + (num_cols * i), vec[i].data(), num_cols * sizeof(Dtype));
    }
  }
  return ret;
}
// the Pybind default Python interpreter doesn't contain some paths in
// sys.path, so we have to add it.
//
// Note: the Pybind default Python interpreter only uses default Python.
// Something may be wrong when users use virtual Python environment.
void AddDistPkgToPythonSysPath() {
pybind11::module sys_py_mod = pybind11::module::import("sys");
// short version such as "3.7", "3.8", ...
std::string py_short_version =
sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
std::string site_pkg_str =
"/usr/local/lib/python" + py_short_version + "/dist-packages";
sys_py_mod.attr("path").attr("append")(site_pkg_str);
// TODO(zhhsplendid): warning to users if setuptools hasn't been installed
DIR* site_pkg_dir = opendir(site_pkg_str.c_str());
if (site_pkg_dir != nullptr) {
std::regex setuptool_regex("setuptools-.*-py" + py_short_version +
"\\.egg");
struct dirent* entry = nullptr;
while ((entry = readdir(site_pkg_dir)) != nullptr) {
if (std::regex_match(entry->d_name, setuptool_regex)) {
sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" +
entry->d_name);
}
}
closedir(site_pkg_dir);
}
}
// Set up the Python side of the cost model: call the interpreter guard
// (presumably to make sure the embedded interpreter is alive -- see
// python_interpreter_guard.h), extend sys.path once for the first instance,
// then import xgboost and create an empty Booster.
XgbCostModel::XgbCostModel() {
  common::PythonInterpreterGuard::Guard();

  // fetch_add returns the value before the increment, so only the very first
  // constructed model sees zero here.
  const bool is_first_instance = (xgb_cost_model_count_.fetch_add(1) == 0);
  if (is_first_instance) {
    AddDistPkgToPythonSysPath();
  }

  xgb_module_ = pybind11::module::import("xgboost");
  xgb_booster_ = xgb_module_.attr("Booster")();
}
void XgbCostModel::Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_ = samples;
update_labels_ = labels;
pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::array np_labels = VectorToNumpy<float>(labels);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
}
std::vector<float> XgbCostModel::Predict(
const std::vector<std::vector<float>>& samples) const {
pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples);
pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix);
return py_result.cast<std::vector<float>>();
}
void XgbCostModel::Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_.insert(update_samples_.end(), samples.begin(), samples.end());
update_labels_.insert(update_labels_.end(), labels.begin(), labels.end());
pybind11::array np_samples = VectorToNumpy<float>(update_samples_);
pybind11::array np_labels = VectorToNumpy<float>(update_labels_);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
}
void XgbCostModel::Save(const std::string& path) {
xgb_booster_.attr("save_model")(pybind11::str(path));
}
void XgbCostModel::Load(const std::string& path) {
xgb_booster_.attr("load_model")(pybind11::str(path));
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/embed.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "paddle/cinn/common/cost_model.h"
namespace cinn {
namespace auto_schedule {
/**
 * A C++ cost model which calls Python xgboost via pybind.
 *
 * Note: this class handles the Python interpreter lifetime internally.
 * If you have to call other Python functions outside of this class and run
 * into a lifetime conflict, check cinn::common::PythonInterpreterGuard.
 *
 * For cinn::common::PythonInterpreterGuard, see:
 * cinn/common/python_interpreter_guard.h .cc
 *
 * For pybind interpreter lifetime management, see:
 *
 * https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#interpreter-lifetime
 * https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv422initialize_interpreterbiPPCKcb
 */
class XgbCostModel : public CostModel {
public:
// Imports the Python xgboost module and creates an empty Booster.
XgbCostModel();
~XgbCostModel() = default;
// Trains a new booster on `samples` (one feature vector per row) and
// `labels`; the data replaces any previously remembered training data.
void Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
// Returns one prediction per row of `samples` from the current booster.
std::vector<float> Predict(
const std::vector<std::vector<float>>& samples) const override;
// Appends `samples`/`labels` to the remembered training data and retrains
// the booster on the whole accumulated data set.
void Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
// Saves the current booster to `path` via Booster.save_model.
void Save(const std::string& path) override;
// Loads a model from `path` into the current booster via Booster.load_model.
void Load(const std::string& path) override;
private:
// Python xgboost module
pybind11::module xgb_module_;
// Object points to Python xgb.Booster()
pybind11::object xgb_booster_;
// atomic int to handle python interpreter lifetime and package dependency
static std::atomic<int> xgb_cost_model_count_;
// Default train rounds
static constexpr int kTrainRound_ = 10;
// Training data remembered from Train() and extended by every Update().
std::vector<std::vector<float>> update_samples_;
std::vector<float> update_labels_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <vector>
namespace cinn {
namespace auto_schedule {
// End-to-end smoke test of XgbCostModel:
//   1. train on random features with constant labels and predict;
//   2. save the model, load it into a second instance and verify that the
//      loaded model reproduces the original predictions;
//   3. update the original model with the same data and predict again.
TEST(CostModel, Basic) {
  XgbCostModel cost_model;

  srand(time(nullptr));

  int batch_size = 16;
  int feature_size = 8;
  std::vector<float> labels(batch_size, 1.0);
  std::vector<std::vector<float>> samples(batch_size,
                                          std::vector<float>(feature_size));
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < feature_size; ++j) {
      samples[i][j] = rand() % 10;  // NOLINT
    }
  }

  cost_model.Train(samples, labels);
  std::vector<float> pred = cost_model.Predict(samples);

  std::string path = "./test_cost_model.cpp_save_model";
  cost_model.Save(path);

  XgbCostModel load_cost_model;
  load_cost_model.Load(path);
  // Predict with the *loaded* model: comparing the original model against
  // itself would pass even if Save/Load were broken.
  std::vector<float> load_pred = load_cost_model.Predict(samples);

  ASSERT_EQ(pred.size(), load_pred.size());
  for (size_t i = 0; i < pred.size(); ++i) {
    ASSERT_FLOAT_EQ(pred[i], load_pred[i]);
    VLOG(6) << "pred[" << i << "] = " << pred[i];
  }
  // Clean up the model file written above.
  std::remove(path.c_str());

  cost_model.Update(samples, labels);
  pred = cost_model.Predict(samples);
  for (size_t i = 0; i < pred.size(); ++i) {
    VLOG(6) << "pred[" << i << "] = " << pred[i];
  }
}
} // namespace auto_schedule
} // namespace cinn
# Project macro; presumably collects the headers of this directory.
core_gather_headers()
# Add the database implementation files to the cinnapi_src source list.
gather_srcs(cinnapi_src SRCS database.cc jsonfile_database.cc)
# Unit tests for the in-memory database and the JSON-file database.
cinn_cc_test(test_database SRCS database_test.cc DEPS cinncore)
cinn_cc_test(test_jsonfile_database SRCS jsonfile_database_test.cc DEPS
cinncore)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/database.h"
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/util/json_util.h>
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
namespace cinn {
namespace auto_schedule {
// Strict ordering by measured execution cost: a record sorts in front of
// another one when its execution_cost is smaller.
bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
                                       const TuningRecord& rhs) const {
  const double lhs_cost = lhs.execution_cost;
  const double rhs_cost = rhs.execution_cost;
  return rhs_cost > lhs_cost;
}
// Build the protobuf representation of this record by copying the task key,
// both costs and the schedule trace into a new proto::TuningRecord.
proto::TuningRecord TuningRecord::ToProto() const {
  proto::TuningRecord serialized;
  serialized.set_task_key(task_key);
  serialized.set_execution_cost(execution_cost);
  serialized.set_predicted_cost(predicted_cost);
  auto* serialized_trace = serialized.mutable_trace();
  serialized_trace->CopyFrom(trace);
  return serialized;
}
Database::Database(int capacity_per_task)
: capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0)
<< "capacity_per_task_ should be greater than 0";
}
// Factory: create the database implementation selected by config.type.
// kMemory yields the plain in-memory Database, kJSONFile a JSONFileDatabase
// on config.record_file_path (creating the file if it is missing). Any other
// value is a fatal error.
std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
  switch (config.type) {
    case DatabaseType::kMemory:
      return std::make_unique<Database>(config.capacity_per_task);
    case DatabaseType::kJSONFile:
      return std::make_unique<JSONFileDatabase>(
          config.capacity_per_task, config.record_file_path, true);
    default:
      break;
  }

  LOG(FATAL) << "Unimplemented database type.";
  return nullptr;
}
// Put the record into the sorted per-task set. If the set then exceeds the
// per-task capacity, drop its last element, i.e. the one that sorts last
// under TuningRecord::Compare.
void Database::Insert(const TuningRecord& record) {
  auto& task_records = key2record_[record.task_key];
  task_records.emplace(record);
  if (task_records.size() <= capacity_per_task_) {
    return;
  }
  auto last = task_records.end();
  --last;
  task_records.erase(last);
}
bool Database::AddRecord(const TuningRecord& record) {
CHECK(!record.task_key.empty()) << "task_key of TuningRecord can't be empty";
Insert(record);
return Commit(record);
}
// Return copies of all records stored for `task_key`, in sorted order, or an
// empty vector when the key is unknown.
std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
  const auto found = key2record_.find(task_key);
  if (found != key2record_.end()) {
    const auto& stored = found->second;
    return std::vector<TuningRecord>(stored.begin(), stored.end());
  }
  return {};
}
std::vector<TuningRecord> Database::GetTopK(const std::string& task_key,
int k) {
auto fit = key2record_.find(task_key);
if (fit == key2record_.end() || k <= 0) {
return {};
}
if (k > capacity_per_task_) {
LOG(WARNING) << "Top k=" << k
<< " is greater than the capacity, will adjust k="
<< capacity_per_task_;
k = capacity_per_task_;
}
std::vector<TuningRecord> results;
results.reserve(k);
for (const TuningRecord& record : fit->second) {
results.emplace_back(record);
if (results.size() == k) {
break;
}
}
return results;
}
// Total number of stored records, summed over all task keys.
size_t Database::Size() {
  size_t total = 0;
  for (const auto& key_and_records : key2record_) {
    total += key_and_records.second.size();
  }
  return total;
}
// Number of records stored for `task_key`; zero when the key is unknown.
size_t Database::Count(const std::string& task_key) {
  const auto found = key2record_.find(task_key);
  return found != key2record_.end() ? found->second.size() : 0;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/ir/schedule/schedule_desc.pb.h"
namespace cinn {
namespace auto_schedule {
// Record related data about tuning process of a measure candidate
struct TuningRecord {
  // the unique key to identify a task
  std::string task_key;
  // the predicted cost of CostModel
  // (zero-initialized so a default-constructed record never holds an
  // indeterminate value)
  float predicted_cost = 0.0f;  // unit: us
  // the ScheduleDesc of this tuning process
  ir::proto::ScheduleDesc trace;
  // the cost time of the candidate executed during measure
  // (zero-initialized as well; Compare reads this field)
  double execution_cost = 0.0;  // unit: us

  TuningRecord() = default;

  // Build a record from its protobuf form by copying every field.
  explicit TuningRecord(const proto::TuningRecord& record)
      : task_key(record.task_key()),
        predicted_cost(record.predicted_cost()),
        trace(record.trace()),
        execution_cost(record.execution_cost()) {}

  // Build a record for `task_key` from a search state and its measured cost:
  // the predicted cost and the schedule trace are taken from `state`.
  TuningRecord(const std::string& task_key,
               const SearchState& state,
               double execution_cost)
      : task_key(task_key),
        predicted_cost(state->predicted_cost),
        trace(state->ir_schedule.GetTraceDesc().ToProto()),
        execution_cost(execution_cost) {}

  // convert to proto object
  proto::TuningRecord ToProto() const;

  // a binary compare function that denotes when the left
  // will be sorted in the front of the right
  struct Compare {
    bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const;
  };
};
// Kind of storage backing a Database: kMemory selects the plain in-memory
// Database, kJSONFile selects JSONFileDatabase (see Database::Make).
enum class DatabaseType : int { kMemory, kJSONFile };
// Options consumed by Database::Make to create a database.
struct DatabaseConfig {
// which database implementation to create
DatabaseType type = DatabaseType::kMemory;
// the max number of records kept per task key
int capacity_per_task = 2;
// file used by the JSON-file database to save/load records
std::string record_file_path = "/tmp/tuning_record.json";
};
// A database supports insert or lookup historical tuning result with specified
// traits. It can be implemented with a concrete storage to save/load underlying
// data, such as memory, file, database server and so on, this base class can be
// regarded as one using memory as its underlying storage medium.
class Database {
 public:
  // `capacity_per_task` is the max number of records kept per task key and
  // must be greater than 0.
  explicit Database(int capacity_per_task);
  // Virtual because Make() hands out derived databases through a
  // std::unique_ptr<Database>; destroying a derived object through a base
  // pointer requires a virtual destructor.
  virtual ~Database() = default;

  // Create a Database with the specific config
  static std::unique_ptr<Database> Make(const DatabaseConfig& config);

  // add a record into the database
  bool AddRecord(const TuningRecord& record);
  // return all records whose task_keys are equal to the specified key
  std::vector<TuningRecord> LookUp(const std::string& task_key);
  // return the states of the top k in sorted candidates
  std::vector<TuningRecord> GetTopK(const std::string& task_key, int k);
  // return the total number of stored candidates
  size_t Size();
  // return the number of stored candidates with specified key
  size_t Count(const std::string& task_key);

 protected:
  // commit the newly added record into underlying storage
  // (the in-memory base class has nothing to persist and reports success)
  virtual bool Commit(const TuningRecord& record) { return true; }
  // insert a newly added record into memory storage
  void Insert(const TuningRecord& record);

  // map task_key to its records
  std::unordered_map<std::string,
                     std::multiset<TuningRecord, TuningRecord::Compare>>
      key2record_;
  // the max number of candidates stored
  const int capacity_per_task_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/database.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// Fixture providing an in-memory Database with capacity 2 per task that is
// pre-filled with seven records spread over the keys k1..k4. All records
// share one SearchState built from an empty IRSchedule.
class TestDatabase : public ::testing::Test {
 public:
  TestDatabase() : test_db(2) {
    auto state = SearchState(ir::IRSchedule());
    const auto add = [this, &state](const std::string& task_key,
                                    double execution_cost) {
      test_db.AddRecord(TuningRecord(task_key, state, execution_cost));
    };
    add("k1", 1.0);
    add("k2", 2.0);
    add("k2", 3.0);
    add("k3", 3.0);
    add("k3", 4.0);
    add("k3", 5.0);
    add("k4", 4.0);
  }

  void SetUp() override {}

  Database test_db;
};
// Size/Count/LookUp on the pre-filled fixture: k3 received three records but
// only the two cheapest survive because the capacity per task is 2.
TEST_F(TestDatabase, Basic) {
  ASSERT_EQ(test_db.Size(), 6);

  const std::vector<TuningRecord> k3_records = test_db.LookUp("k3");
  // the number of stored candidates is limited by capacity_per_task
  ASSERT_EQ(test_db.Count("k3"), 2);
  ASSERT_EQ(k3_records.size(), 2);
  EXPECT_EQ(k3_records[0].execution_cost, 3.0);
  EXPECT_EQ(k3_records[1].execution_cost, 4.0);
}
// GetTopK: an unknown key yields nothing, a key with one record yields one,
// and after adding two cheaper records to k4 the top-3 query is clamped to
// the capacity (2) and returns them ordered by execution cost.
TEST_F(TestDatabase, GetTopK) {
  ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
  ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);

  const auto add_k4 = [this](double predicted_cost, double execution_cost) {
    test_db.AddRecord(TuningRecord(
        "k4", SearchState(ir::IRSchedule(), predicted_cost), execution_cost));
  };
  add_k4(1.2, 2.0);
  add_k4(1.0, 3.0);

  const std::vector<TuningRecord> top_records = test_db.GetTopK("k4", 3);
  ASSERT_EQ(top_records.size(), 2);
  EXPECT_FLOAT_EQ(top_records[0].predicted_cost, 1.2);
  EXPECT_FLOAT_EQ(top_records[1].predicted_cost, 1.0);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/util/json_util.h>
#include <fstream>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/utils/multi_threading.h"
namespace cinn {
namespace auto_schedule {
// Open `file_path` in append mode and write `line` followed by a newline,
// flushing the stream afterwards. Aborts if the file cannot be opened.
void AppendLineToFile(const std::string& file_path, const std::string& line) {
  std::ofstream os(file_path, std::ofstream::app);
  CHECK(os.good()) << "Cannot open the file to write: " << file_path;
  os << line << '\n';
  os.flush();
}
// Read `file_path` line by line and return the lines. If the file cannot be
// opened: abort unless `allow_new_file` is set, otherwise create an empty
// file at that path and return no lines.
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
                                           bool allow_new_file) {
  std::ifstream is(file_path);
  if (!is.good()) {
    CHECK(allow_new_file) << "File doesn't exist: " << file_path;
    std::ofstream os(file_path);
    CHECK(os.good()) << "Cannot create new file: " << file_path;
    return {};
  }

  std::vector<std::string> lines;
  std::string current_line;
  while (std::getline(is, current_line)) {
    lines.push_back(current_line);
  }
  return lines;
}
// Load the records stored in `record_file_path` (one JSON object per line).
// Every line is parsed into a proto::TuningRecord via utils::parallel_run;
// afterwards only records whose task key is known to the global
// InitialTaskRegistry are inserted into the in-memory storage.
JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
                                   const std::string& record_file_path,
                                   bool allow_new_file)
    : Database(capacity_per_task), record_file_path_(record_file_path) {
  VLOG(3) << "Auto schedule will save/load tuning records on file:"
          << record_file_path;

  auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);
  const size_t num_lines = json_lines.size();
  std::vector<cinn::auto_schedule::proto::TuningRecord> parsed_records(
      num_lines);

  // convert the JSON string at `index` to a proto object; each worker call
  // writes only its own slot of parsed_records
  auto parse_line = [&json_lines, &parsed_records](int index) {
    cinn::auto_schedule::proto::TuningRecord parsed;
    auto status =
        google::protobuf::util::JsonStringToMessage(json_lines[index], &parsed);
    CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
    parsed_records[index].Swap(&parsed);
  };
  utils::parallel_run(parse_line, utils::SequenceDispatcher(0, num_lines), -1);

  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  for (const auto& record_proto : parsed_records) {
    const std::string& task_key = record_proto.task_key();
    if (!task_registry->Has(task_key)) {
      continue;
    }
    VLOG(4) << "Add a measured TuningRecord with task_key=" << task_key;
    Insert(TuningRecord(record_proto));
  }
}
// convert a TuningRecord object to string in JSON format
std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) {
proto::TuningRecord record_proto = record.ToProto();
std::string json_string;
auto status =
google::protobuf::util::MessageToJsonString(record_proto, &json_string);
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = "
<< record.task_key;
VLOG(4) << "json_string = \n" << json_string;
return json_string;
}
bool JSONFileDatabase::Commit(const TuningRecord& record) {
std::string json_string = RecordToJSON(record);
AppendLineToFile(record_file_path_, json_string);
return true;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/database/database.h"
namespace cinn {
namespace auto_schedule {
// JSONFileDatabase is a database implemented by JSON file to save/load
// underlying data. Existing records are read from the file on construction
// and every newly added record is appended to it as one line.
class JSONFileDatabase : public Database {
public:
/*!
 * \brief Build a JSONFileDatabase object from a json file.
 * \param capacity_per_task The max number of candidates stored.
 * \param record_file_path The path of the json file.
 * \param allow_new_file Whether to create new file when the given path is not
 * found.
 */
JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file);
~JSONFileDatabase() = default;
// convert a TuningRecord object to string in JSON format
std::string RecordToJSON(const TuningRecord& record);
protected:
// commit the newly added record into json file
bool Commit(const TuningRecord& record) override;
// the name of the json file to save tuning records.
std::string record_file_path_;
};
// Append `line` plus a trailing newline to the file at `file_path`.
void AppendLineToFile(const std::string& file_path, const std::string& line);
// Read all lines of the file at `file_path`. `allow_new_file` controls what
// happens when the file cannot be opened; it defaults to true.
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file = true);
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include <google/protobuf/util/message_differencer.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
// Return lowerd ir AST for example functions used in this test
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
const Target& target) {
CHECK(shape.size() == 2) << "shape should be 2";
std::vector<Expr> domain;
for (auto i = 0; i < shape.size(); ++i) {
domain.emplace_back(shape[i]);
}
Placeholder<float> A("A", domain);
ir::Tensor B, C;
B = Compute(
domain, [&A](Var i, Var j) { return A(i, j); }, "B");
C = Compute(
domain, [&B](Var i, Var j) { return B(i, j); }, "C");
return cinn::lang::LowerVec(
"test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
}
// Copy the body of every lowered function, register the copies in the global
// InitialTaskRegistry under `task_key`, and return a new IRSchedule built
// from the same copied expressions.
ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
                              const std::string& task_key) {
  std::vector<Expr> copied_bodies;
  for (const auto& func : lowered_funcs) {
    copied_bodies.emplace_back(optim::IRCopy(func->body));
  }

  InitialTaskRegistry::Global()->Regist(task_key,
                                        ir::ModuleExpr(copied_bodies));
  return ir::IRSchedule(ir::ModuleExpr(copied_bodies));
}
// Fixture for the JSON-file database tests. Each test gets a
// JSONFileDatabase (capacity 2 per task) backed by /tmp/test_record.json,
// freshly lowered example functions, and the record file is removed again
// in TearDown.
class TestJSONFileDatabase : public ::testing::Test {
 public:
  // record_file_path is declared before test_db, so it is already
  // initialized when it is passed to the database constructor.
  TestJSONFileDatabase()
      : record_file_path("/tmp/test_record.json"),
        test_db(2, record_file_path, true) {}

  void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); }

  // Delete the record file so records do not leak into the next test.
  void TearDown() override {
    auto isFileExists = [](const std::string& file_path) -> bool {
      std::ifstream f(file_path.c_str());
      return f.good();
    };
    if (isFileExists(record_file_path)) {
      if (remove(record_file_path.c_str()) == 0) {
        LOG(INFO) << "Successfully deleted file: " << record_file_path;
      } else {
        LOG(INFO) << "failed to delete file: " << record_file_path;
      }
    } else {
      // A space separates the path from the rest of the message.
      LOG(INFO) << "file: " << record_file_path << " does not exist.";
    }
  }

  std::string record_file_path;
  JSONFileDatabase test_db;
  std::vector<ir::LoweredFunc> lowered_funcs;
  Target target = common::DefaultHostTarget();
};
// Checks RecordToJSON on a record whose schedule trace contains one Fuse
// step: the produced string must equal one of the two accepted JSON forms,
// which differ only in the order of the step's attributes.
TEST_F(TestJSONFileDatabase, Serialize) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "test");
auto fused = ir_sch.Fuse("B", {0, 1});
VLOG(3) << "after Fuse, Expr: " << fused;
// predicted cost 2.0, execution cost 1.0
TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0);
std::string str = test_db.RecordToJSON(record1);
VLOG(3) << "RecordToJSON: " << str;
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
// case1: "loops_index" attribute first, then "block_name"
std::string case1 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
// case2: "block_name" attribute first, then "loops_index"
std::string case2 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, str == case1 || str == case2);
}
// Adds two records and checks what was written to the record file: one line
// per record, in insertion order. The k1 record carries a Fuse step; the k2
// record has an empty trace.
TEST_F(TestJSONFileDatabase, SaveLoad) {
ir::IRSchedule ir_sch1 = MakeIRSchedule(lowered_funcs, "k1");
auto fused1 = ir_sch1.Fuse("B", {0, 1});
ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2");
test_db.AddRecord(
TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0));
test_db.AddRecord(
TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0));
// read the file back directly, bypassing the database
std::vector<std::string> strs = ReadLinesFromFile(record_file_path);
ASSERT_EQ(strs.size(), 2);
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
// case1: "loops_index" attribute first, then "block_name"
std::string case1 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
// case2: "block_name" attribute first, then "loops_index"
std::string case2 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2);
EXPECT_EQ(strs[1],
"{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,"
"\"trace\":{}}");
}
// Size/Count/LookUp after adding seven records over the keys k1..k4: k3
// receives three records but only the two cheapest survive because the
// capacity per task is 2.
TEST_F(TestJSONFileDatabase, Basic) {
  const auto add_record = [this](const std::string& task_key,
                                 double predicted_cost,
                                 double execution_cost) {
    test_db.AddRecord(TuningRecord(
        task_key,
        SearchState(MakeIRSchedule(lowered_funcs, task_key), predicted_cost),
        execution_cost));
  };
  add_record("k1", 1.0, 1.0);
  add_record("k2", 1.0, 2.0);
  add_record("k2", 1.0, 3.0);
  add_record("k3", 8.0, 3.0);
  add_record("k3", 7.0, 4.0);
  add_record("k3", 6.0, 5.0);
  add_record("k4", 1.0, 4.0);

  ASSERT_EQ(test_db.Size(), 6);
  const std::vector<TuningRecord> k3_records = test_db.LookUp("k3");
  // the number of stored candidates is limited by capacity_per_task
  ASSERT_EQ(test_db.Count("k3"), 2);
  ASSERT_EQ(k3_records.size(), 2);
  EXPECT_EQ(k3_records[0].execution_cost, 3.0);
  EXPECT_EQ(k3_records[1].execution_cost, 4.0);
}
// GetTopK on k4 after three insertions: the request for 3 records is clamped
// to the capacity (2) and the two records with the lowest execution cost are
// returned in that order.
TEST_F(TestJSONFileDatabase, GetTopK) {
  const auto add_record = [this](const std::string& task_key,
                                 double predicted_cost,
                                 double execution_cost) {
    test_db.AddRecord(TuningRecord(
        task_key,
        SearchState(MakeIRSchedule(lowered_funcs, task_key), predicted_cost),
        execution_cost));
  };
  add_record("k1", 1.0, 1.0);
  add_record("k2", 1.0, 2.0);
  add_record("k2", 1.0, 3.0);
  add_record("k3", 1.0, 3.0);
  add_record("k3", 1.0, 4.0);
  add_record("k3", 1.0, 5.0);
  add_record("k4", 2.0, 4.0);
  add_record("k4", 1.2, 2.0);
  add_record("k4", 1.0, 3.0);

  const std::vector<TuningRecord> top_records = test_db.GetTopK("k4", 3);
  ASSERT_EQ(top_records.size(), 2);
  EXPECT_FLOAT_EQ(top_records[0].predicted_cost, 1.2);
  EXPECT_FLOAT_EQ(top_records[1].predicted_cost, 1.0);
}
// Writes two records through the fixture database, then opens a second
// JSONFileDatabase on the same file and checks that the reloaded k1 record
// matches the original one: scalar fields, the trace proto, and the module
// expressions obtained by replaying both traces.
TEST_F(TestJSONFileDatabase, Reload) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1");
auto fused = ir_sch.Fuse("B", {0, 1});
test_db.AddRecord(
TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
auto records = test_db.LookUp("k1");
ASSERT_EQ(records.size(), 1);
// allow_new_file is false here: the record file must already exist
JSONFileDatabase new_db(2, record_file_path, false);
ASSERT_EQ(new_db.Size(), 2);
auto loaded_records = new_db.LookUp("k1");
ASSERT_EQ(records.size(), loaded_records.size());
EXPECT_EQ(records[0].task_key, loaded_records[0].task_key);
EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost);
EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost);
// check the equality of trace info between original TuningRecord and the
// loaded TuningRecord
const auto& lhs_trace = records[0].trace;
const auto& rhs_trace = loaded_records[0].trace;
google::protobuf::util::MessageDifferencer dif;
static const google::protobuf::Descriptor* descriptor =
cinn::ir::proto::ScheduleDesc_Step::descriptor();
// compare the "attrs" field as a set, i.e. ignoring element order
dif.TreatAsSet(descriptor->FindFieldByName("attrs"));
EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace));
// check the equality of module expr between original TuningRecord
// and the loaded TuningRecord by replaying with tracing ScheduleDesc
ir::IRSchedule lhs_sch = MakeIRSchedule(lowered_funcs, "k1");
ir::IRSchedule rhs_sch = MakeIRSchedule(lowered_funcs, "k1");
ir::ScheduleDesc::ReplayWithProto(lhs_trace, &lhs_sch);
ir::ScheduleDesc::ReplayWithProto(rhs_trace, &rhs_sch);
auto lhs_exprs = lhs_sch.GetModule().GetExprs();
auto rhs_exprs = rhs_sch.GetModule().GetExprs();
ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
for (auto i = 0; i < lhs_exprs.size(); ++i) {
std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i));
std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i));
// NOTE(review): the first 28 characters of each printed expr are dropped
// before comparing -- presumably a prefix that differs between the two
// schedules; confirm what it contains.
size_t remove_prefix_len = 28;
ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len));
}
}
} // namespace auto_schedule
} // namespace cinn
# Build rules for the measurement sources and their unit tests.
# NOTE(review): core_gather_headers, gather_srcs and cinn_cc_test are
# project-defined macros/functions not visible here; the comments below
# describe only what the arguments show -- confirm against their definitions.

# Called with no arguments; presumably collects this directory's headers.
core_gather_headers()
# Passes the three measurer sources under the name cinnapi_src.
gather_srcs(cinnapi_src SRCS schedule_measurer.cc simple_builder.cc
simple_runner.cc)
# Two test targets, each built from one test source and depending on cinncore.
cinn_cc_test(test_simple_runner SRCS simple_runner_test.cc DEPS cinncore)
cinn_cc_test(test_measurer SRCS measurer_test.cc DEPS cinncore)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace cinn {
namespace auto_schedule {
// The input to a measurer
struct MeasureInput {
  // The task object related to this measurement. Defaults to null so that a
  // default-constructed MeasureInput never holds an indeterminate pointer.
  const TuneTask* task = nullptr;
  // lowered Exprs to be measured
  std::vector<ir::LoweredFunc> lowered_funcs;
  // Optional arguments whose values are specified in advance, keyed by
  // name. Default is null.
  const std::map<std::string, cinn_pod_value_t>* execution_args = nullptr;
};
// Outcome of measuring one schedule.
struct MeasureResult {
  // Average execution time over a specific number of repeated runs,
  // in microseconds.
  double execution_cost{0.0};
  // Time spent on the whole measurement process, building and running
  // included, in microseconds.
  double elapsed_time{0.0};
  // Detail message describing an error that occurred during measurement;
  // stays empty if nothing goes wrong.
  std::string error_msg{};
};
// The result of building with input schedule
struct BuildResult {
  // The scope that owns detail compilation infos of parameters in the runtime
  // program. Defaults to null so that a default-constructed BuildResult never
  // holds an indeterminate pointer.
  const hlir::framework::Scope* compiled_scope = nullptr;
  // The executable program; this struct holds it through a unique_ptr and is
  // therefore move-only.
  std::unique_ptr<hlir::framework::Program> runtime_program;
};
// This interface defines how to generate executable objects
// from an input schedule. A builder should not contain stateful data
// related to any task, so that it can be called in parallel among multiple
// processes of task tuning.
class ScheduleBuilder {
 public:
  // Virtual so that an implementation can be destroyed safely through a
  // pointer to this interface.
  virtual ~ScheduleBuilder() = default;

  // Builds the schedule described by `input` and returns the build result.
  virtual BuildResult Build(const MeasureInput& input) = 0;
};
// This interface defines how to run the built result. Like the above
// ScheduleBuilder, a runner should be implemented without being bound to a
// specific task.
class ScheduleRunner {
 public:
  // Virtual so that an implementation can be destroyed safely through a
  // pointer to this interface.
  virtual ~ScheduleRunner() = default;

  // Runs `build_result`, which was built for `input`, and returns the
  // measurement result.
  virtual MeasureResult Run(const MeasureInput& input,
                            const BuildResult& build_result) = 0;
};
} // namespace auto_schedule
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment