Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
add_subdirectory(utils)
add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR set(PYTHON_TESTS_DIR
${PADDLE_BINARY_DIR}/python/paddle/fluid/tests ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests
CACHE INTERNAL "python tests directory") CACHE INTERNAL "python tests directory")
add_subdirectory(utils)
add_subdirectory(ir)
add_subdirectory(scripts)
add_subdirectory(testing)
add_subdirectory(phi) add_subdirectory(phi)
add_subdirectory(infrt)
add_subdirectory(fluid) add_subdirectory(fluid)
# NOTE(zhiqiu): The changes of cc tests
# Before, (1) the source files of cc tests were distributed in different sub-directories,
# (2) the tests were added and configured by calling `cc_test()` in each `CMakeLists.txt`,
# (3) the tests linked against the static libraries of paddle modules,
# (4) the test binaries were generated in different directories, the same as the
# folders of their source files.
# Now, we want to make all cc tests dynamically linked to the main paddle library,
# i.e., `libpaddle.so`, so we change the logic of (2), (3), (4):
# (2) calling `cc_test()` in each `CMakeLists.txt` will not actually add the test, but
# only record the test and its source files; the action of adding tests is deferred to HERE.
# Why do so? Because the target `libpaddle.so` is mostly the last target, and
# the tests should be added after it according to the dependency.
# (3) the tests link the dynamic library `libpaddle.so`,
# (4) the tests are generated in the same directory, i.e., `CC_TESTS_DIR` defined above.
# Next, (to be discussed)
# (1) move all source files to same folder,
# (2) naturally, and configure tests in only one `CMakeLists.txt`,
# (3) cc tests support linking pre-built dynamic libraries. For example, use the dynamic
# library in the installed paddle by `pip`.
# Define the `cinn_gtest_main` library (built from gtest_main.cc, depending on
# gtest and gflags) only when WITH_TESTING is enabled.
if(WITH_TESTING)
cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest gflags)
endif()
# Add every component sub-directory to the build.
add_subdirectory(api)
add_subdirectory(auto_schedule)
add_subdirectory(common)
add_subdirectory(utils)
add_subdirectory(poly)
add_subdirectory(runtime)
add_subdirectory(ir)
add_subdirectory(backends)
add_subdirectory(lang)
add_subdirectory(optim)
add_subdirectory(hlir)
add_subdirectory(pybind)
add_subdirectory(frontend)
# Download a model: fetch "lite_naive_model.tar.gz" from PADDLE_RESOURCE_URL
# into DOWNLOAD_MODEL_DIR.
download_and_uncompress("${DOWNLOAD_MODEL_DIR}" "${PADDLE_RESOURCE_URL}"
"lite_naive_model.tar.gz")
# NOTE(review): core_gather_headers() and gather_srcs() are project macros
# defined elsewhere; presumably they collect this directory's headers and
# append the listed sources to `cinnapi_src` -- confirm against their
# definitions.
core_gather_headers()
core_gather_headers()
gather_srcs(cinnapi_src SRCS op_node.cc tensor_node.cc)
# Print the current value of `cinnapi_src` at configure time.
message(STATUS "srcs: ${cinnapi_src}")
The classes in this directory are the interface of the group fusion pass. You can use these APIs to build the strategy for group fusion.
The Class and APIs are following:
`OpGroup` : A set of op nodes, which will be passed to the CINN backend for generating kernel code. Two groups can be fused together according to the merging rules written in the passes.
`OpNode` : Map the op in the program.
`TensorNode` : Map the tensor in the program.
`Shape` : The shape information of a tensor.
`FusePassCtx` : The context is the parameter of the pass; it holds all the data you need in the pass.
`FuseHelper` : We provide some util methods such as `DetectCycleIfFuse` in fuse_helper to simplify development of pass.
| Class | method | description |
| :--: | :--: | :--: |
| OpGroup | kind()| Get the Kind of group |
| | producers()| Get producer groups of current group |
| | consumers() | Get consumer groups of current group |
| | WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) | Visit the op_nodes in the group and execute the VisitOpNode function for each OpNode |
| | | |
| OpNode | kind() | Get the Kind of op_node |
| | inputs() | Get input tensors of op_node |
| | outputs() | Get output tensors of op_node |
| | GetAttr(const std::string& attr_name) | Get attribute of op_node by attr name |
| | | |
| TensorNode | shape() | Get shape of tensor |
| | producer() | Get the producer op_node of tensor |
| | consumers() | Get the consumer op_nodes of tensor |
| | | |
| Shape | numel() | Get total number of elements in the shape |
| | other methods are the same as those of std::vector<int64_t> | |
| | | |
| LightwareFusePassCtx | PickOpGroup() | Get the current group in the pass context |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups which can be fused together |
| | fuse_helper() | Get the fuse_helper provided by pass context |
| | | |
| InputFusePassCtx | PickConsumersWithSameInputs() | Get all consumer groups for input tensors of graph |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups which can be fused together |
| | fuse_helper() | Get the fuse_helper provided by pass context |
| | | |
| FuseHelper | DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) | Whether there is cycle in graph after fusing two groups |
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/op_node.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
namespace cinn {
namespace api {
class OpGroup {
public:
explicit OpGroup(const std::shared_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
OpGroup(const OpGroup& other) = default;
using Comparator = hlir::framework::Graph::Group::SharedGroupComparator;
using Hasher = hlir::framework::Graph::Group::SharedGroupHasher;
class OpGroupListIterator {
public:
OpGroupListIterator(
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>,
Hasher,
Comparator>::const_iterator it)
: iter_(it) {}
OpGroupListIterator& operator++() {
++iter_;
return *this;
}
OpGroupListIterator operator++(int) {
OpGroupListIterator tmp = *this;
++iter_;
return tmp;
}
bool operator==(const OpGroupListIterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const OpGroupListIterator& other) const {
return !(*this == other);
}
OpGroup operator*() const { return OpGroup(*iter_); }
private:
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>,
Hasher,
Comparator>::const_iterator iter_;
};
class ProducerOpGroupListView {
public:
ProducerOpGroupListView(
const std::weak_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete;
ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete;
ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) =
delete;
using const_iterator = OpGroupListIterator;
size_t size() const {
CHECK(group_.lock());
return group_.lock()->producer_groups().size();
}
const_iterator begin() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->producer_groups().begin());
}
const_iterator end() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->producer_groups().end());
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
class ConsumerOpGroupListView {
public:
ConsumerOpGroupListView(
const std::weak_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete;
ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete;
ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) =
delete;
using const_iterator = OpGroupListIterator;
size_t size() const {
CHECK(group_.lock());
return group_.lock()->consumer_groups().size();
}
const_iterator begin() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->consumer_groups().begin());
}
const_iterator end() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->consumer_groups().end());
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
const std::string& group_id() const { return group_.lock()->group_id; }
hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); }
// The WalkOpNodes function is used to traverse the op_nodes in the group and
// execute the VisitOpNode function for each OpNode. This function is
// equivalent to for loop for op_nodes in graph.
//
// In order to avoid unnecessary memory copies, we use WalkOpNodes function
// instead of providing a function to get all op_nodes directly.
//
// Example: Get the all Reduction op_nodes in the group.
// OpGroup group = ...;
// std::set<api::OpNode> reduce_ op_set;
// // The lambda funtion of VisitOpNode to get reduction op_nodes.
// auto get_reduce_op = [&reduce_op_set](const api::OpNode& op){
// if (op.kind() == OpPatternKind::kReduction) {
// reduce_op_set.insert(op);
// }
// };
// group.WalkOpNodes(get_reduce_op);
void WalkOpNodes(
const std::function<void(const OpNode&)>& VisitOpNode) const {
group_.lock()->WalkNodes([&](const hlir::framework::Node* node) {
VisitOpNode(OpNode(node, group_.lock()->graph_));
});
}
ProducerOpGroupListView producers() const {
return ProducerOpGroupListView(group_);
}
ConsumerOpGroupListView consumers() const {
return ConsumerOpGroupListView(group_);
}
std::shared_ptr<hlir::framework::Graph::Group> GetGroup() const {
return group_.lock();
}
bool operator==(const OpGroup& other) const {
return group_.lock().get() == other.group_.lock().get();
}
bool operator<(const OpGroup& other) const {
return group_.lock().get() < other.group_.lock().get();
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
} // namespace api
} // namespace cinn
namespace std {
template <>
struct hash<cinn::api::OpGroup> {
  // Hashes an OpGroup by the address of the group object it refers to.
  size_t operator()(const cinn::api::OpGroup& obj) const {
    const auto* raw_group = obj.GetGroup().get();
    const size_t address = reinterpret_cast<size_t>(raw_group);
    return std::hash<size_t>()(address);
  }
};
} // namespace std
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/api/op_node.h"
namespace cinn {
namespace api {
// Maps the edge the iterator currently points at to its tensor (through the
// edge-to-tensor callback supplied at construction) and wraps it in a
// TensorNode.
TensorNode OpNode::TensorListIterator::operator*() const {
  auto* tensor = get_tensor_from_edge_(*iter_);
  return TensorNode(tensor, graph_);
}
// Returns the index-th input tensor, i.e. the source of the index-th edge.
// No bounds checking is performed on `index`.
TensorNode OpNode::InputTensorListView::operator[](size_t index) const {
  const auto& edge = edges_[index];
  auto* tensor = edge->source()->safe_as<hlir::framework::NodeData>();
  return TensorNode(tensor, graph_);
}
// Returns the index-th output tensor, i.e. the sink of the index-th edge.
// No bounds checking is performed on `index`.
TensorNode OpNode::OutputTensorListView::operator[](size_t index) const {
  const auto& edge = edges_[index];
  auto* tensor = edge->sink()->safe_as<hlir::framework::NodeData>();
  return TensorNode(tensor, graph_);
}
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/tensor_node.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
namespace cinn {
namespace api {
// OpNode is a read-only view of one operator node of a graph. It exposes the
// op's pattern kind, its input / output tensors and its attributes.
class OpNode {
 public:
  OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph)
      : node_(node),
        graph_(graph),
        input_tensors_(node_->inlinks_in_order(), graph_),
        output_tensors_(node_->outlinks_in_order(), graph_) {}

  // The tensor views are not copyable, so a copy rebuilds them from the node.
  OpNode(const OpNode& other)
      : node_(other.node_),
        graph_(other.graph_),
        input_tensors_(node_->inlinks_in_order(), graph_),
        output_tensors_(node_->outlinks_in_order(), graph_) {}

  using OpPatternKind = cinn::hlir::framework::OpPatternKind;

  // Returns the pattern kind registered for this op under "OpPattern".
  // Ops registered as kBroadcast are reported as kElementWise unless the op
  // is named "broadcast_to".
  OpPatternKind kind() const {
    static const hlir::framework::OpValueType<OpPatternKind>& op_pattern_dict =
        hlir::framework::Operator::GetAttrs<OpPatternKind>("OpPattern");
    const auto pattern = op_pattern_dict[node_->op()];
    // As binary op was defined as broadcast, actually it should be
    // element-wise.
    if (pattern == hlir::framework::kBroadcast &&
        node_->op()->name != "broadcast_to") {
      return hlir::framework::kElementWise;
    }
    return pattern;
  }

  // Forward iterator over a list of graph edges. Dereferencing converts the
  // current edge into a TensorNode with the callback given at construction.
  class TensorListIterator {
   public:
    TensorListIterator(
        std::vector<common::Shared<common::GraphEdge>>::const_iterator it,
        const hlir::framework::Graph* graph,
        std::function<hlir::framework::NodeData*(
            common::Shared<common::GraphEdge>)> get_tensor_from_edge)
        : iter_(it),
          graph_(graph),
          get_tensor_from_edge_(get_tensor_from_edge) {}

    TensorListIterator& operator++() {
      ++iter_;
      return *this;
    }

    TensorListIterator operator++(int) {
      TensorListIterator previous(*this);
      ++iter_;
      return previous;
    }

    bool operator==(const TensorListIterator& other) const {
      return iter_ == other.iter_;
    }

    bool operator!=(const TensorListIterator& other) const {
      return !(iter_ == other.iter_);
    }

    TensorNode operator*() const;

   private:
    std::vector<common::Shared<common::GraphEdge>>::const_iterator iter_;
    const hlir::framework::Graph* graph_;
    std::function<hlir::framework::NodeData*(common::Shared<common::GraphEdge>)>
        get_tensor_from_edge_;
  };

  using const_iterator = TensorListIterator;

  // View of the input tensors of an op: the source of each incoming edge.
  // The view keeps its own copy of the edge list.
  class InputTensorListView {
   public:
    InputTensorListView(
        const std::vector<common::Shared<common::GraphEdge>>& edges,
        const hlir::framework::Graph* graph)
        : edges_(edges), graph_(graph) {}

    InputTensorListView(const InputTensorListView& other) = delete;
    InputTensorListView(InputTensorListView&& other) = delete;
    InputTensorListView& operator=(const InputTensorListView& other) = delete;

    size_t size() const { return edges_.size(); }

    TensorNode operator[](size_t index) const;

    const_iterator begin() const { return MakeIterator(edges_.begin()); }

    const_iterator end() const { return MakeIterator(edges_.end()); }

   private:
    // Wraps a raw edge iterator; every edge is mapped to its source tensor.
    const_iterator MakeIterator(
        std::vector<common::Shared<common::GraphEdge>>::const_iterator it)
        const {
      return const_iterator(
          it, graph_, [](common::Shared<common::GraphEdge> edge) {
            return edge->source()->safe_as<hlir::framework::NodeData>();
          });
    }

    std::vector<common::Shared<common::GraphEdge>> edges_;
    const hlir::framework::Graph* graph_;
  };

  // View of the output tensors of an op: the sink of each outgoing edge.
  // The view keeps its own copy of the edge list.
  class OutputTensorListView {
   public:
    OutputTensorListView(
        const std::vector<common::Shared<common::GraphEdge>>& edges,
        const hlir::framework::Graph* graph)
        : edges_(edges), graph_(graph) {}

    OutputTensorListView(const OutputTensorListView& other) = delete;
    OutputTensorListView(OutputTensorListView&& other) = delete;
    OutputTensorListView& operator=(const OutputTensorListView& other) = delete;

    size_t size() const { return edges_.size(); }

    TensorNode operator[](size_t index) const;

    const_iterator begin() const { return MakeIterator(edges_.begin()); }

    const_iterator end() const { return MakeIterator(edges_.end()); }

   private:
    // Wraps a raw edge iterator; every edge is mapped to its sink tensor.
    const_iterator MakeIterator(
        std::vector<common::Shared<common::GraphEdge>>::const_iterator it)
        const {
      return const_iterator(
          it, graph_, [](common::Shared<common::GraphEdge> edge) {
            return edge->sink()->safe_as<hlir::framework::NodeData>();
          });
    }

    std::vector<common::Shared<common::GraphEdge>> edges_;
    const hlir::framework::Graph* graph_;
  };

  // Identity and ordering are both based on the address of the wrapped node.
  bool operator==(const OpNode& other) const { return node_ == other.node_; }

  bool operator<(const OpNode& other) const { return node_ < other.node_; }

  const InputTensorListView& inputs() const { return input_tensors_; }

  const OutputTensorListView& outputs() const { return output_tensors_; }

  // Returns the attribute `attr_name` as type T. The attribute is looked up
  // with at() and converted with absl::get<T>.
  template <typename T>
  const T& GetAttr(const std::string& attr_name) const {
    const auto& attribute = GetAttr(attr_name);
    return absl::get<T>(attribute);
  }

 private:
  using Attribute = cinn::utils::Attribute;

  // Untyped attribute lookup in the node's attribute store.
  const Attribute& GetAttr(const std::string& attr_name) const {
    const auto& store = node_->attrs.attr_store;
    return store.at(attr_name);
  }

  friend struct std::hash<OpNode>;

  const hlir::framework::Node* node_;
  const hlir::framework::Graph* graph_;
  const InputTensorListView input_tensors_;
  const OutputTensorListView output_tensors_;
};
} // namespace api
} // namespace cinn
namespace std {
template <>
struct hash<cinn::api::OpNode> {
  // Hashes an OpNode by the address of the graph node it wraps.
  size_t operator()(const cinn::api::OpNode& obj) const {
    const size_t address = reinterpret_cast<size_t>(obj.node_);
    return std::hash<size_t>()(address);
  }
};
} // namespace std
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
#include "paddle/cinn/utils/small_vector.h"
#include "paddle/cinn/utils/type_defs.h"
namespace cinn {
namespace api {
// Shape is an immutable copy of the dimensions of a tensor. Dimensions are
// stored as int64_t and exposed through the accessors as size_t.
class Shape final {
 public:
  explicit Shape(const utils::ShapeType& shape)
      : shape_(shape.begin(), shape.end()) {}

  Shape(const Shape& other) = delete;
  Shape(Shape&& other) = delete;
  Shape& operator=(const Shape& other) = delete;

  bool operator==(const Shape& other) const { return shape_ == other.shape_; }

  // Returns the extent of dimension `index`; no bounds checking is done.
  size_t operator[](size_t index) const { return shape_[index]; }

  // Same as operator[]; NOTE that, unlike std::vector::at, this accessor
  // does not check `index`.
  size_t at(size_t index) const { return shape_[index]; }

  // Returns the number of dimensions.
  size_t size() const { return shape_.size(); }

  // Returns the total number of elements in the shape (1 for an empty
  // shape). The product is accumulated in size_t so that large shapes are
  // not truncated to the range of `int`.
  size_t numel() const {
    size_t total = 1;
    for (const int64_t dim : shape_) {
      total *= static_cast<size_t>(dim);
    }
    return total;
  }

 private:
  cinn::utils::SmallVector<int64_t, 12> shape_;
};
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/api/tensor_node.h"
#include "paddle/cinn/api/op_node.h"
namespace cinn {
namespace api {
// Returns the op node that produces this tensor. Tensors without a producer
// (see HasProducer()) have a null source node, which OpNode's constructor
// would dereference, so fail with a clear message instead.
OpNode TensorNode::producer() const {
  CHECK(HasProducer()) << "Tensor " << node_data_->id()
                       << " has no producer op.";
  return OpNode(node_data_->source_node.get(), graph_);
}
// Wraps the sink node of the edge the iterator currently points at in an
// OpNode.
OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const {
  const auto& edge = *iter_;
  auto* consumer = edge->sink()->safe_as<hlir::framework::Node>();
  return OpNode(consumer, graph_);
}
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/shape.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
#include "paddle/cinn/utils/small_vector.h"
#include "paddle/cinn/utils/type_defs.h"
namespace cinn {
namespace api {
class OpNode;
class TensorNode final {
public:
TensorNode(const hlir::framework::NodeData* node_data,
const hlir::framework::Graph* graph)
: node_data_(node_data),
graph_(graph),
consumers_(node_data_->outlinks(), graph_) {
const auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data_->id()))
<< "Can't find " << node_data_->id() << " 's shape!";
shape_ = std::make_shared<Shape>(shape_dict.find(node_data_->id())->second);
}
// Get the shape of tensor.
const Shape& shape() const { return *shape_; }
// Input data has no producer.
bool HasProducer() const { return node_data_->source_node.get() != nullptr; }
OpNode producer() const;
class ConsumerOpListView {
public:
ConsumerOpListView(const std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>& edges,
const hlir::framework::Graph* graph)
: edges_(edges), graph_(graph) {}
ConsumerOpListView(const ConsumerOpListView& other) = delete;
ConsumerOpListView(ConsumerOpListView&& other) = delete;
ConsumerOpListView& operator=(const ConsumerOpListView& other) = delete;
class Iterator {
public:
Iterator(std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>::const_iterator it,
const hlir::framework::Graph* graph)
: iter_(it), graph_(graph) {}
Iterator& operator++() {
++iter_;
return *this;
}
Iterator operator++(int) {
Iterator tmp = *this;
++iter_;
return tmp;
}
bool operator==(const Iterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const Iterator& other) const { return !(*this == other); }
OpNode operator*() const;
private:
std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>::const_iterator iter_;
const hlir::framework::Graph* graph_;
};
size_t size() const { return edges_.size(); }
Iterator begin() const { return Iterator(this->edges_.begin(), graph_); }
Iterator end() const { return Iterator(this->edges_.end(), graph_); }
private:
const std::set<Shared<common::GraphEdge>, common::GraphEdgeCompare>& edges_;
const hlir::framework::Graph* graph_;
};
const ConsumerOpListView& consumers() const { return consumers_; }
private:
const hlir::framework::NodeData* node_data_;
const hlir::framework::Graph* graph_;
std::shared_ptr<Shape> shape_;
const ConsumerOpListView consumers_;
};
} // namespace api
} // namespace cinn
# Add every auto_schedule component sub-directory to the build.
add_subdirectory(analysis)
add_subdirectory(cost_model)
add_subdirectory(database)
add_subdirectory(measure)
add_subdirectory(post_schedule_rule)
add_subdirectory(search_space)
add_subdirectory(search_strategy)
add_subdirectory(task)
add_subdirectory(task_scheduler)
add_subdirectory(tests)
# Declare the `auto_schedule_proto` library from auto_schedule.proto; it
# depends on `schedule_desc_proto`.
cinn_proto_library(auto_schedule_proto SRCS auto_schedule.proto DEPS
schedule_desc_proto)
core_gather_headers()
gather_srcs(cinnapi_src SRCS auto_tuner.cc)
# The auto tuner test below is currently disabled.
#cinn_cc_test(test_auto_tuner SRCS auto_tuner_test.cc DEPS cinncore)
# Append each header listed in `auto_schedule_proto_HDRS` to the cached
# `core_proto_includes` list.
foreach(header ${auto_schedule_proto_HDRS})
set(core_proto_includes
"${core_proto_includes};${header}"
CACHE INTERNAL "")
endforeach()
core_gather_headers()
gather_srcs(cinnapi_src SRCS analyze_ir.cc)
# Register the unit test for analyze_ir; it depends on `cinncore`.
cinn_cc_test(test_analyze_ir SRCS analyze_ir_test.cc DEPS cinncore)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
namespace cinn {
namespace auto_schedule {
// Returns copies of all index expressions that are plain variables, in their
// original order. Indices of any other expression type are skipped.
std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
  std::vector<ir::Var> vars;
  for (const ir::Expr& index : indices) {
    // Whether we have to convert other types, like const numbers to Var?
    if (index.As<ir::_Var_>() == nullptr) {
      continue;
    }
    // Work on a copy so the returned Var does not alias the original node.
    ir::Expr copied_index = optim::IRCopy(index);
    ir::_Var_* copied_var = copied_index.As<ir::_Var_>();
    vars.emplace_back(ir::Var(copied_var));
  }
  return vars;
}
// Fills sche_block->read_buffers / write_buffers from the Load / Store nodes
// found in the block body. Does nothing when either list is already
// non-empty.
void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
  const bool already_analyzed =
      !(sche_block->read_buffers.empty() && sche_block->write_buffers.empty());
  if (already_analyzed) {
    return;
  }

  // The visitor is only used for its side effects: it always returns false,
  // so no node is collected into the (discarded) result.
  ir::CollectIRNodesWithoutTensor(sche_block->body, [&](const Expr* x) {
    if (const ir::Load* load = x->As<ir::Load>()) {
      const ir::Tensor tensor = load->tensor.as_tensor_ref();
      sche_block->read_buffers.emplace_back(
          ir::BufferRange(tensor->buffer, IndicesToVars(load->indices)));
    } else if (const ir::Store* store = x->As<ir::Store>()) {
      const ir::Tensor tensor = store->tensor.as_tensor_ref();
      sche_block->write_buffers.emplace_back(
          ir::BufferRange(tensor->buffer, IndicesToVars(store->indices)));
    }
    return false;
  });
}
// Returns true if `expr` contains at least one node whose type is listed in
// `node_types`.
bool ContainsNodeType(ir::Expr expr,
                      const std::unordered_set<ir::IrNodeTy>& node_types) {
  const auto is_wanted_type = [&](const Expr* x) {
    return node_types.count(x->node_type()) != 0;
  };
  std::set<ir::Expr> matches =
      ir::CollectIRNodesWithoutTensor(expr, is_wanted_type);
  return matches.size() > 0;
}
// Collects the names of all output arguments across `lowered_funcs`.
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
    const std::vector<ir::LoweredFunc>& lowered_funcs) {
  std::unordered_set<std::string> output_names;
  for (const ir::LoweredFunc& lowered_func : lowered_funcs) {
    for (const ir::Argument& argument : lowered_func->args) {
      if (!argument.is_output()) {
        continue;
      }
      output_names.insert(argument.name());
    }
  }
  return output_names;
}
// Decides whether a schedule block needs multi-level tiling.
//
// Returns false unless the block has exactly one write buffer and at least
// one read buffer. Otherwise, for every read buffer other than the write
// buffer itself, it counts the non-reduce block iter vars that do not appear
// among the vars indexing that read region, and returns true when the total
// count is at least one.
bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
  const ir::ScheduleBlock* sche_block =
      sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
  if (sche_block->write_buffers.size() != 1 ||
      sche_block->read_buffers.empty()) {
    return false;
  }
  const ir::Expr& write_buffer =
      sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;

  // Enumerate each read region, get the number of schedule block iter vars
  // which are not used to index the read region
  int total_unused_iter_vars = 0;
  for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) {
    const ir::_BufferRange_* read_buffer =
        read_buffer_expr.As<ir::_BufferRange_>();
    // Skip the reduction buffer
    if (read_buffer->buffer == write_buffer) {
      continue;
    }

    // Collect the vars in schedule block that are used to index the read region
    std::unordered_set<std::string> vars_index_read;
    for (const Var& range : read_buffer->ranges) {
      vars_index_read.insert(range->name);
    }

    // Check the block iter vars are not used to index the read region
    int n_unused_block_vars = 0;
    for (const ir::Var& block_iter_var : sche_block->iter_vars) {
      if (block_iter_var->is_reduce_axis) {
        continue;
      }
      // Hash lookup instead of a linear scan over the set.
      if (vars_index_read.count(block_iter_var->name) == 0) {
        ++n_unused_block_vars;
      }
    }
    total_unused_iter_vars += n_unused_block_vars;
  }
  return total_unused_iter_vars >= 1;
}
// Builds a new LoweredFunc that keeps the name and arguments of `old_func`
// but uses `body` as the function body.
//
// Steps, in order:
//  1. Wrap `body` in an IRSchedule and, for every temp buffer of `old_func`,
//     mark the first schedule block whose name prefixed with "_" equals the
//     buffer name as using a "local" buffer.
//  2. When built with CINN_WITH_CUDA, run OptimizeExprGPU on the body.
//  3. Recompute the temp buffers from the updated body and create the new
//     function.
//  4. When built with CINN_WITH_CUDA and `target` is the default NVGPU
//     target, prepare the CUDA axis info from the body.
//  5. Run optim::Optimize on the new function and prepare its buffer cast
//     expressions.
//
// `body` is taken by non-const reference (hence the NOLINT).
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body) { // NOLINT
ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body}));
ir::IRSchedule ir_sch(mod_expr);
// temp_bufs may be deleted during auto tuning (such as auto inline),
// we have to check from old temp bufs and set them as local buffer.
for (const ir::Buffer& buf : old_func->temp_bufs) {
const std::string& buf_name = buf->name;
// The block list is re-queried for every buffer, after any SetBuffer call
// made for a previous buffer.
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
for (ir::Expr& e : all_block_realizes) {
const ir::ScheduleBlockRealize* sche_block_realize =
e.As<ir::ScheduleBlockRealize>();
const std::string& sche_name =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
// A temp buffer is matched to a block by the naming pattern "_" + block name.
if (buf_name == "_" + sche_name) {
VLOG(6) << "Set local buffer for temp buffer " << buf_name;
ir_sch.SetBuffer(e, "local", true);
// Only the first matching block is handled for each buffer.
break;
}
}
}
// Only the first expression of the scheduled module is used as the new body.
ir::Expr updated_body = ir_sch.GetModule().GetExprs()[0];
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&updated_body);
#endif
// Get new temp bufs by analyzing.
std::vector<ir::Buffer> new_temp_bufs =
lang::GetTempBuffers(old_func->args, updated_body);
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(
old_func->name, old_func->args, updated_body, new_temp_bufs);
#ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) {
new_func->PrepareCudaAxisInfoFromBody();
}
#endif
// NOTE(review): the meaning of the trailing `false` argument is defined by
// optim::Optimize, which is not visible here -- confirm before changing.
new_func =
optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false);
return new_func;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
 * Fills the read_buffers / write_buffers of a schedule block from the Load and
 * Store nodes found in its body. Does nothing when either list is already
 * non-empty.
 */
void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
/**
 * Returns true if `expr` contains at least one node whose type is listed in
 * `node_types`
 */
bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types);
/**
 * Collects all input lowered_funcs and returns the names of all output
 * arguments
 */
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs);
/**
 * Determine whether a schedule block needs multi-level tiling
 */
bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize);
/**
 * Update a LoweredFunc by regenerating related fields with a new function body
 */
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body); // NOLINT
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <sstream>
#include <vector>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace auto_schedule {
// Checks that AnalyzeScheduleBlockReadWriteBuffer records the read and write
// buffer regions of a plain element-wise copy B[i, j] = A[i, j] over a 32x32
// domain.
TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  ir::Expr M(32);
  ir::Expr N(32);

  lang::Placeholder<float> A("A", {M, N});
  ir::Tensor B = lang::Compute(
      {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");

  poly::StageMap stages = poly::CreateStages({A, B});
  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
      "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);

  ASSERT_FALSE(funcs.empty());
  ir::Expr ast_expr = funcs[0]->body;

  VLOG(6) << "Analyzing for Expr:";
  VLOG(6) << ast_expr;

  std::vector<Expr> vec_ast{ast_expr};
  ir::ModuleExpr mod_expr(vec_ast);
  ir::IRSchedule ir_sch(mod_expr);

  std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
  ASSERT_EQ(all_block_realizes.size(), 1UL);

  // As<>() hands back raw pointers; fail the test cleanly instead of
  // dereferencing a null one if the node is not of the expected type.
  ir::ScheduleBlockRealize* sche_block_realize =
      all_block_realizes[0].As<ir::ScheduleBlockRealize>();
  ASSERT_TRUE(sche_block_realize != nullptr);
  ir::ScheduleBlock* sche_block =
      sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
  ASSERT_TRUE(sche_block != nullptr);

  AnalyzeScheduleBlockReadWriteBuffer(sche_block);

  /*
   * After the analysis the sche_block_realize is expected to be:
   * ScheduleBlock(B)
   * {
   *   i0, i1 = axis.bind(i, j)
   *   read_buffers(_A[i0(0:32), i1(0:32)])
   *   write_buffers(_B[i0(0:32), i1(0:32)])
   *   B[i0, i1] = A[i0, i1]
   * }
   */
  VLOG(6) << "ScheduleBlockRealize: ";
  VLOG(6) << all_block_realizes[0];

  ASSERT_EQ(sche_block->read_buffers.size(), 1UL);
  std::stringstream read_ss;
  read_ss << sche_block->read_buffers[0];
  ASSERT_EQ(read_ss.str(), "_A[i0(0:32), i1(0:32)]");

  ASSERT_EQ(sche_block->write_buffers.size(), 1UL);
  std::stringstream write_ss;
  write_ss << sche_block->write_buffers[0];
  ASSERT_EQ(write_ss.str(), "_B[i0(0:32), i1(0:32)]");
}
// Checks the read/write buffer analysis for C[i, j] = A[i] + B[j], where the
// two inputs are one-dimensional with different extents (32 and 128).
TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  ir::Expr M(32);
  ir::Expr N(128);

  lang::Placeholder<float> A("A", {M});
  lang::Placeholder<float> B("B", {N});
  ir::Tensor C = lang::Compute(
      {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");

  poly::StageMap stages = poly::CreateStages({C});
  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
      "AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);

  // Same guard as the sibling tests: never index an empty lowering result.
  ASSERT_FALSE(funcs.empty());
  ir::Expr ast_expr = funcs[0]->body;

  VLOG(6) << "Expr before MultiLevelTiling: ";
  VLOG(6) << ast_expr;

  std::vector<Expr> vec_ast{ast_expr};
  ir::ModuleExpr mod_expr(vec_ast);
  ir::IRSchedule ir_sch(mod_expr);

  std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
  ASSERT_EQ(all_block_realizes.size(), 1UL);

  // As<>() hands back raw pointers; fail the test cleanly instead of
  // dereferencing a null one if the node is not of the expected type.
  ir::ScheduleBlockRealize* sche_block_realize =
      all_block_realizes[0].As<ir::ScheduleBlockRealize>();
  ASSERT_TRUE(sche_block_realize != nullptr);
  ir::ScheduleBlock* sche_block =
      sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
  ASSERT_TRUE(sche_block != nullptr);

  AnalyzeScheduleBlockReadWriteBuffer(sche_block);

  VLOG(6) << "ScheduleBlockRealize: ";
  VLOG(6) << all_block_realizes[0];

  // One read region per input tensor, in the order A then B.
  std::vector<std::string> expect_read = {"_A[i0(0:32)]", "_B[i1(0:128)]"};
  ASSERT_EQ(sche_block->read_buffers.size(), expect_read.size());
  for (size_t i = 0; i < expect_read.size(); ++i) {
    std::stringstream read_ss;
    read_ss << sche_block->read_buffers[i];
    ASSERT_EQ(read_ss.str(), expect_read[i]);
  }

  ASSERT_EQ(sche_block->write_buffers.size(), 1UL);
  std::stringstream write_ss;
  write_ss << sche_block->write_buffers[0];
  ASSERT_EQ(write_ss.str(), "_C[i0(0:32), i1(0:128)]");
}
// Lowers a 32x32 element-wise copy and probes its body with three different
// sets of node types.
TEST(AnalyzeIr, ContainsNodeType) {
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  ir::Expr rows(32);
  ir::Expr cols(32);

  lang::Placeholder<float> A("A", {rows, cols});
  ir::Tensor B = lang::Compute(
      {rows, cols}, [&](Var i, Var j) { return A(i, j); }, "B");

  poly::StageMap stage_map = poly::CreateStages({A, B});
  std::vector<ir::LoweredFunc> lowered = lang::LowerVec(
      "SimpleAssign", stage_map, {A, B}, {}, {}, nullptr, target, true);
  ASSERT_FALSE(lowered.empty());

  ir::Expr lowered_body = lowered[0]->body;
  VLOG(6) << "Analyzing for Expr:";
  VLOG(6) << lowered_body;

  // Node-type sets under test, named after their contents.
  const std::unordered_set<ir::IrNodeTy> load_store{ir::IrNodeTy::Load,
                                                    ir::IrNodeTy::Store};
  const std::unordered_set<ir::IrNodeTy> load_if{ir::IrNodeTy::Load,
                                                 ir::IrNodeTy::IfThenElse};
  const std::unordered_set<ir::IrNodeTy> if_sum{ir::IrNodeTy::IfThenElse,
                                                ir::IrNodeTy::Sum};

  ASSERT_TRUE(ContainsNodeType(lowered_body, load_store));
  ASSERT_TRUE(ContainsNodeType(lowered_body, load_if));
  ASSERT_FALSE(ContainsNodeType(lowered_body, if_sum));
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax ="proto3";
package cinn.auto_schedule.proto;
import "paddle/cinn/ir/schedule/schedule_desc.proto";
// One tuning record of the auto-schedule module.
// NOTE(review): presumably this is what the tuning database persists -
// confirm against the database implementation.
message TuningRecord {
// Key identifying the tuned task.
// NOTE(review): presumably the task's serialized key - confirm.
string task_key = 1;
// Cost obtained by executing the schedule; the unit is not specified here.
double execution_cost = 2;
// Cost predicted for the schedule; the unit is not specified here.
double predicted_cost = 3;
// Schedule description, as defined in schedule_desc.proto.
cinn.ir.proto.ScheduleDesc trace = 4;
}
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/auto_tuner.h"
#include <glog/logging.h>
#include <pybind11/embed.h>
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include "paddle/cinn/auto_schedule/measure/simple_builder.h"
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace auto_schedule {
// Binds the tuner to a target and a graph. Only a reference to `target` and
// the raw `graph` pointer are stored (see the member declarations), so both
// objects have to outlive the tuner. No tuning components are created here;
// that is done by Initialize().
AutoTuner::AutoTuner(const common::Target& target,
hlir::framework::Graph* graph)
: target_(target), graph_(graph) {}
void AutoTuner::Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler) {
// create builder, runner, and schedule measurer
builder_ = std::make_unique<SimpleBuilder>(graph_compiler);
runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times);
schedule_measurer_ =
std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
// initialize database
database_ = std::move(Database::Make(config.database_config));
// create tasks
TaskCreator task_creator;
tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);
const auto& dtype_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph_->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks_.size(); ++i) {
auto&& task = tasks_[i];
task.Initialize(shape_dict, dtype_dict, op_lowerer_.get());
// Register the initial ModuleExpr corresponding to the task
task_registry->Regist(task.serialized_key,
ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n"
<< task.serialized_key;
}
// create task optimizers
utils::LinearRandomEngine::StateType initial_seed =
utils::LinearRandomEngine::GetDeviceRandomValue();
task_optimizers_.resize(tasks_.size());
std::transform(tasks_.begin(),
tasks_.end(),
task_optimizers_.begin(),
[&](TuneTask& task) {
return std::make_unique<TaskOptimizer>(
&task,
schedule_measurer_.get(),
database_.get(),
utils::ForkRandomState(&initial_seed));
});
// create task scheduler
task_scheduler_ = TaskScheduler::Make(
tasks_, config.task_schedule_config, config.task_schedule_strategy);
}
// Logs the nodes of one tuned sub-graph at verbosity level 3; does nothing
// when that level is disabled.
void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) {
  if (VLOG_IS_ON(3)) {
    auto group_nodes = group->CollectNodes();
    VLOG(3) << "Node size:" << group_nodes.size();

    VLOG(3) << "Group {";
    for (auto* graph_node : group_nodes) {
      VLOG(3) << " " << hlir::framework::DebugString(graph_node);
    }
    VLOG(3) << "}";
  }
}
// Logs every lowered function of a tuned function group at verbosity level 3;
// does nothing when that level is disabled.
void PrintResult(const FunctionGroup& functions) {
  if (!VLOG_IS_ON(3)) {
    return;
  }

  VLOG(3) << "Function size:" << functions.size();
  // size_t index: avoids comparing a signed int against functions.size()
  for (size_t i = 0; i < functions.size(); ++i) {
    const ir::LoweredFunc& func = functions.at(i);
    VLOG(3) << "LoweredFunc-" << i << " detail:\n" << func;
  }
}
// Logs a complete tuning result (all sub-graphs, then all function groups) at
// verbosity level 3; does nothing when that level is disabled.
void PrintResult(const TuningResult& result) {
  if (!VLOG_IS_ON(3)) {
    return;
  }

  VLOG(3) << "###### Debug TuningResult ######\n";
  VLOG(3) << "Tuned SubGraph num:" << result.subgraphs.size();
  // size_t indices: avoid comparing a signed int against size()
  for (size_t i = 0; i < result.subgraphs.size(); ++i) {
    VLOG(3) << "****** SubGraph-" << i << " Detail ******\n";
    PrintResult(result.subgraphs.at(i));
    VLOG(3) << "****** SubGraph End ******";
  }

  VLOG(3) << "Tuned FunctionGroup num:" << result.function_groups.size();
  for (size_t i = 0; i < result.function_groups.size(); ++i) {
    VLOG(3) << "****** FunctionGroup-" << i << " Detail ******\n";
    PrintResult(result.function_groups.at(i));
    VLOG(3) << "****** FunctionGroup End ******";
  }
  VLOG(3) << "###### TuningResult End ######";
}
// Runs `options.num_tuning_rounds` tuning rounds over all tasks and returns
// the collected result.
//
// In every round the scheduler is reset and asked for task ids until it
// returns -1. The optimizer of each selected task is run, and its function
// group replaces the previous entry for that task, so the result holds the
// output of the latest optimization of each task.
// Aborts via CHECK when `num_tuning_rounds` is not positive.
TuningResult AutoTuner::Tune(const TuningOptions& options) {
  CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
  VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds
          << ", tasks size=" << tasks_.size();

  TuningResult result;
  result.subgraphs.resize(tasks_.size());
  result.function_groups.resize(tasks_.size());
  // A task only tunes schedule now, so we populate its sub_graph
  // as default result of graph tuning, and that should be updated
  // once we support graph tuning.
  // size_t index: avoids comparing a signed int against tasks_.size()
  for (size_t i = 0; i < tasks_.size(); ++i) {
    const TuneTask& task = tasks_.at(i);
    result.subgraphs[i] = task.subgraph;
  }

  for (int r = 0; r < options.num_tuning_rounds; ++r) {
    VLOG(3) << "<<<<<< Round " << r << " >>>>>>";
    int run_id = -1;
    task_scheduler_->Reset();
    while ((run_id = task_scheduler_->NextTaskId()) != -1) {
      VLOG(3) << "Start tuning Task-" << run_id;
      auto* opt = task_optimizers_.at(run_id).get();
      auto function_group = opt->Optimize(options);
      VLOG(3) << "Task-" << run_id << " finished, print optimized functions:\n";
      PrintResult(function_group);
      // update the best schedules searched so far.
      result.function_groups.at(run_id) = std::move(function_group);
    }
  }

  PrintResult(result);
  return result;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include "paddle/cinn/auto_schedule/task/task_optimizer.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
namespace cinn {
namespace auto_schedule {
// This class is the entry point of auto-tuning: users can use it
// to tune a graph (not supported yet) and to search a series of schedules
// that are more likely to obtain better performance.
// Internally, it creates the necessary components and uses them to perform
// tuning.
//
// Typical usage: construct, call Initialize() once, then call Tune().
class AutoTuner {
public:
// Configures how to perform auto-tuning, such as
// the way to create tasks, the strategy of scheduling tasks and so on.
struct Config {
// Strategy name passed to TaskScheduler::Make.
std::string task_schedule_strategy = "round_robin";
// Scheduler configuration passed to TaskScheduler::Make.
TaskScheduler::Config task_schedule_config;
// Repeat count passed to the SimpleRunner used for measurement.
int runner_repeat_times = 1;
// Settings passed to Database::Make.
DatabaseConfig database_config;
};
// Stores a reference to `target` and the raw `graph` pointer; both must
// outlive this object.
AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
// Initialize tuner with specific config and auxiliary objects.
void Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler);
// Perform the tuning process and return the final result
TuningResult Tune(const TuningOptions& options);
private:
// NOTE(review): members are destroyed in reverse declaration order and
// several of them hold raw pointers to one another (e.g. the measurer is
// built from builder_.get() and runner_.get()); keep the order as is unless
// the destructors involved are checked.
// Reference to the caller-owned target; not a copy.
const common::Target& target_;
// Graph being tuned; stored as a raw pointer handed in by the caller.
hlir::framework::Graph* graph_;
// Lowerer created in Initialize() and handed to each task's Initialize().
std::unique_ptr<hlir::framework::OpLowerer> op_lowerer_;
// Tasks to tune
std::vector<TuneTask> tasks_;
// Scheduler that selects a task to tune at every turn.
std::unique_ptr<TaskScheduler> task_scheduler_;
// The actors that perform auto-tuning; each optimizer takes one task.
std::vector<std::unique_ptr<TaskOptimizer>> task_optimizers_;
// Classes used to measure AutoTune samples
std::unique_ptr<ScheduleBuilder> builder_;
std::unique_ptr<ScheduleRunner> runner_;
std::unique_ptr<ScheduleMeasurer> schedule_measurer_;
// The database to store tuning records
std::unique_ptr<Database> database_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/auto_tuner.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
using ::cinn::hlir::framework::Node;
using ::cinn::hlir::framework::Scope;
class TestAutoTuner : public ::testing::Test {
public:
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
std::shared_ptr<Graph> graph;
std::shared_ptr<Scope> compiled_scope;
std::unique_ptr<GraphCompiler> graph_compiler;
std::unique_ptr<AutoTuner> tuner;
frontend::Program CreateAddReluProgram() {
frontend::NetBuilder builder("test");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
return builder.Build();
}
void SetUp() override {
srand(0);
std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph);
graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
}
TuningResult InitializeAndTune(const AutoTuner::Config& config,
const TuningOptions& options) {
tuner->Initialize(config, graph_compiler.get());
return tuner->Tune(options);
}
virtual void BasicCheckResult(const TuningResult& result) {
ASSERT_EQ(1, result.subgraphs.size());
auto nodes = result.subgraphs.front()->CollectNodes();
ASSERT_EQ(nodes.size(), 4UL);
ASSERT_EQ(nodes[0]->op()->name, "broadcast_to");
ASSERT_EQ(nodes[1]->op()->name, "fill_constant");
ASSERT_EQ(nodes[2]->op()->name, "elementwise_add");
ASSERT_EQ(nodes[3]->op()->name, "max");
ASSERT_EQ(result.function_groups.size(), 1UL);
ASSERT_EQ(result.function_groups[0].size(), 1UL);
}
virtual void ApplyTunedAndRun(const TuningResult& result) {
// build runtime program with tuning result
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
compile_options.Apply(result);
ASSERT_EQ(1, compile_options.groups.size());
ASSERT_EQ(1, compile_options.lowered_funcs.size());
VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0];
VLOG(6) << compile_options.lowered_funcs[1][0];
auto runtime_program =
graph_compiler->Build(compile_options).runtime_program;
ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute();
}
void ZeroMeasure() {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";
TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
}
void NonZeroMeasure() {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";
TuningOptions tuning_options;
tuning_options.num_measure_trials = 4;
tuning_options.num_samples_per_iteration = 2;
auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
}
};
// The four cases below cover both measurement modes (zero / non-zero trials)
// with the cost model switched off and on. Each case sets the global flag
// explicitly before running, so no case depends on the value left behind by
// another one.
TEST_F(TestAutoTuner, ZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
ZeroMeasure();
}
TEST_F(TestAutoTuner, ZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
ZeroMeasure();
}
TEST_F(TestAutoTuner, NonZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
NonZeroMeasure();
}
TEST_F(TestAutoTuner, NonZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
NonZeroMeasure();
}
} // namespace auto_schedule
} // namespace cinn
# Build rules for the cost model sources and their unit tests.
# NOTE(review): core_gather_headers, gather_srcs and cinn_cc_test are project
# helper macros defined elsewhere; the comments below describe their apparent
# intent and should be confirmed against their definitions.

# Presumably collects the headers of this directory.
core_gather_headers()
# Presumably appends the cost model sources to the cinnapi_src source list.
gather_srcs(cinnapi_src SRCS xgb_cost_model.cc expr_cost_model.cc feature.cc
feature_extractor.cc)
# Unit tests, each built from one source file and depending on cinncore.
cinn_cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore)
cinn_cc_test(test_feature_extractor SRCS feature_extractor_test.cc DEPS
cinncore)
cinn_cc_test(test_feature SRCS feature_test.cc DEPS cinncore)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment