Commit 01a10755 authored by yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -19,7 +19,7 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace common {
......
......@@ -78,7 +78,7 @@ std::string NameGenerator::New(const std::string& name_hint) {
} // namespace common
DEFINE_bool(cinn_runtime_display_debug_info,
PD_DEFINE_bool(cinn_runtime_display_debug_info,
false,
"Whether to display debug information in runtime");
} // namespace cinn
......@@ -14,7 +14,6 @@
#pragma once
#include <absl/types/any.h>
#include <gflags/gflags.h>
#include <isl/cpp.h>
#include <mutex>
......@@ -25,10 +24,11 @@
#include "paddle/cinn/common/debug_manager.h"
#include "paddle/cinn/common/info_registry.h"
#include "paddle/cinn/common/target.h"
#include "paddle/utils/flags.h"
namespace cinn {
DECLARE_bool(cinn_runtime_display_debug_info);
PD_DECLARE_bool(cinn_runtime_display_debug_info);
namespace ir {
class Expr;
......@@ -52,6 +52,22 @@ struct NameGenerator {
mutable std::mutex mutex_;
};
// Caches a pretty (human-readable) unique name per hash key, generating a
// fresh one from NameGenerator on first request for that key.
struct PrettyNamer {
  // Returns the name previously generated for `hash_key`, or creates one
  // from `name_hint` on the first call for that key.
  const std::string& GetOrNew(const size_t hash_key,
                              const std::string& name_hint) {
    auto iter = pretty_names_.find(hash_key);
    if (iter == pretty_names_.end()) {
      iter =
          pretty_names_.emplace(hash_key, name_generator_.New(name_hint)).first;
    }
    return iter->second;
  }

  NameGenerator& GetNameGenerator() { return name_generator_; }

 private:
  absl::flat_hash_map<size_t, std::string> pretty_names_;  // hash -> name cache
  NameGenerator name_generator_;
};
class Context {
public:
static Context& Global();
......@@ -61,10 +77,15 @@ class Context {
* @param name_hint The prefix.
*/
std::string NewName(const std::string& name_hint) {
return name_generator_.New(name_hint);
return pretty_namer_.GetNameGenerator().New(name_hint);
}
void ResetNameId() { name_generator_.ResetID(); }
std::string PrettyUniqName(const size_t hash_key,
const std::string& name_hint) {
return pretty_namer_.GetOrNew(hash_key, name_hint);
}
void ResetNameId() { pretty_namer_.GetNameGenerator().ResetID(); }
const std::vector<std::string>& runtime_include_dir();
......@@ -82,7 +103,7 @@ class Context {
private:
Context() = default;
NameGenerator name_generator_;
PrettyNamer pretty_namer_;
std::vector<std::string> runtime_include_dir_;
mutable std::mutex mutex_;
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace cinn {
namespace common {
// Root of the device-info hierarchy.  On its own it only remembers which
// device ordinal it describes; concrete per-device classes derive from it.
class DevInfoBase {
 public:
  // Records the ordinal of the device being described (defaults to device 0).
  explicit DevInfoBase(int device_num = 0) : device_num_{device_num} {}

  // Virtual destructor: instances may be deleted through base pointers.
  virtual ~DevInfoBase() = default;

 protected:
  int device_num_;  // device ordinal supplied at construction
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/common/dev_info_base.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/nvgpu_dev_info.h"
#include "paddle/cinn/common/target.h"
namespace cinn {
namespace common {
// Maps a compile-time Target::Arch value to the matching device-info class.
// The primary template falls back to the plain DevInfoBase.
template <Target::Arch arch>
struct GetDevType {
  using DevType = DevInfoBase;
};

// Extra device should be added here
class NVGPUDevInfo;  // NOTE(review): presumably declared in nvgpu_dev_info.h — verify
template <>
struct GetDevType<Target::Arch::NVGPU> {
  using DevType = NVGPUDevInfo;
};
template <Target::Arch arch>
class DevInfoMgr final {
private:
explicit DevInfoMgr(int device_num = 0) : device_num_(device_num) {
impl_ = std::make_unique<typename GetDevType<arch>::DevType>(device_num);
}
std::unique_ptr<DevInfoBase> impl_;
int device_num_;
public:
static DevInfoMgr<arch> GetDevInfo(int device_num = 0) {
return DevInfoMgr(device_num);
}
using RetType = typename GetDevType<arch>::DevType;
const RetType* operator->() const {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
return dynamic_cast<const RetType*>(impl_.get());
}
RetType* operator->() {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
return dynamic_cast<RetType*>(impl_.get());
}
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <tuple>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/bfs_walker.h"
namespace cinn {
// Walks a bipartite "equation graph" whose nodes alternate between variables
// (VT) and functions (FT).  The graph structure is supplied as three visitor
// callbacks rather than a concrete container, so the same walker works over
// any representation.
template <typename VT, typename FT>
class EquationGraphTopoWalker final {
 public:
  using VariableVisitorT = std::function<void(VT)>;
  using FunctionVisitorT = std::function<void(FT)>;
  // Enumerates the functions adjacent to (consuming) a variable.
  using F4VVisitor = std::function<void(VT, const FunctionVisitorT&)>;
  // Enumerates the variables attached to a function (inputs or outputs).
  using V4FVisitor = std::function<void(FT, const VariableVisitorT&)>;

  EquationGraphTopoWalker(const F4VVisitor& NextFunctionsVisitor,
                          const V4FVisitor& InputVariablesVisitor,
                          const V4FVisitor& OutputVariablesVisitor)
      : VisitNextFunctions(NextFunctionsVisitor),
        VisitInputVariables(InputVariablesVisitor),
        VisitOutputVariables(OutputVariablesVisitor) {}
  ~EquationGraphTopoWalker() = default;

  // Merges two function-enumerating visitors: the merged visitor reports the
  // union of both (lhs first, then rhs; duplicates are not removed here).
  static F4VVisitor Merge(const F4VVisitor& lhs, const F4VVisitor& rhs) {
    return [=](VT variable, const FunctionVisitorT& Visit) {
      lhs(variable, Visit);
      rhs(variable, Visit);
    };
  }
  // Same union-merge for variable-enumerating visitors.
  static V4FVisitor Merge(const V4FVisitor& lhs, const V4FVisitor& rhs) {
    return [=](FT function, const VariableVisitorT& Visit) {
      lhs(function, Visit);
      rhs(function, Visit);
    };
  }
  // Merges two walkers edge-wise: the result walks the union of both graphs.
  EquationGraphTopoWalker Merge(const EquationGraphTopoWalker& that) const {
    return {Merge(this->VisitNextFunctions, that.VisitNextFunctions),
            Merge(this->VisitInputVariables, that.VisitInputVariables),
            Merge(this->VisitOutputVariables, that.VisitOutputVariables)};
  }

  // Topologically visits only the variable nodes reachable from `start`.
  void WalkVariable(VT start, const VariableVisitorT& VariableVisitor) const {
    std::array<VT, 1> starts{start};
    (*this)(starts.begin(), starts.end(), VariableVisitor, [&](FT) {});
  }
  // Range-of-starts overload of the above.
  template <typename VarIterT>
  void WalkVariable(VarIterT begin,
                    VarIterT end,
                    const VariableVisitorT& VariableVisitor) const {
    (*this)(begin, end, VariableVisitor, [&](FT) {});
  }
  // Topologically visits only the function nodes reachable from `start`.
  void WalkFunction(VT start, const FunctionVisitorT& FunctionVisitor) const {
    std::array<VT, 1> starts{start};
    (*this)(
        starts.begin(), starts.end(), [&](VT) {}, FunctionVisitor);
  }
  // Range-of-starts overload of the above.
  template <typename VarIterT>
  void WalkFunction(VarIterT begin,
                    VarIterT end,
                    const FunctionVisitorT& FunctionVisitor) const {
    (*this)(
        begin, end, [&](VT) {}, FunctionVisitor);
  }

  // Breadth-first (not topological) visit of the functions reachable from
  // `variable`, ignoring input-readiness ordering.
  void BfsWalkFunction(VT variable,
                       const FunctionVisitorT& FunctionVisitor) const {
    std::array<VT, 1> array{variable};
    BfsWalkFunction(array.begin(), array.end(), FunctionVisitor);
  }
  // BFS over functions: two functions are neighbors when they share any
  // input or output variable.
  template <typename VarIterT>
  void BfsWalkFunction(VarIterT begin,
                       VarIterT end,
                       const FunctionVisitorT& FunctionVisitor) const {
    using F4FVisitor = std::function<void(FT, const FunctionVisitorT&)>;
    // Function-to-function adjacency derived from shared variables.
    F4FVisitor BfsVisitNextFunction = [&](FT f,
                                          const FunctionVisitorT& DoEach) {
      VisitInputVariables(
          f, [&](VT variable) { VisitNextFunctions(variable, DoEach); });
      VisitOutputVariables(
          f, [&](VT variable) { VisitNextFunctions(variable, DoEach); });
    };
    // Seed the BFS with every function directly adjacent to a start variable.
    std::vector<FT> starts{};
    for (VarIterT iter = begin; iter != end; ++iter) {
      VisitNextFunctions(*iter, [&](FT f) { starts.emplace_back(f); });
    }
    common::BfsWalker<FT> bfs_walker{BfsVisitNextFunction};
    bfs_walker(starts.begin(), starts.end(), FunctionVisitor);
  }

  // Core topological walk.  Variables in [begin, end) are the sources.
  // A function is visited only after all of its input variables have been
  // enqueued; its output variables are then enqueued in turn.  Each round
  // drains at most one function and one variable, interleaving the two kinds.
  template <typename VarIterT>
  void operator()(VarIterT begin,
                  VarIterT end,
                  const VariableVisitorT& VariableVisitor,
                  const FunctionVisitorT& FunctionVisitor) const {
    std::queue<VT> variables_queue{};
    std::unordered_set<VT> queued_variables{};
    std::queue<FT> functions_queue{};
    std::unordered_set<FT> queued_functions{};
    // (sic: "Varaible") Enqueue a variable exactly once.
    const auto& TryEnqueueVaraible = [&](VT variable) {
      if (queued_variables.count(variable) == 0) {
        variables_queue.push(variable);
        queued_variables.insert(variable);
      }
    };
    // Enqueue a function exactly once.
    const auto& TryEnqueueFunction = [&](FT function) {
      if (queued_functions.count(function) == 0) {
        functions_queue.push(function);
        queued_functions.insert(function);
      }
    };
    for (VarIterT iter = begin; iter != end; ++iter) {
      TryEnqueueVaraible(*iter);
    }
    while (!functions_queue.empty() || !variables_queue.empty()) {
      if (!functions_queue.empty()) {
        FT function = functions_queue.front();
        functions_queue.pop();
        FunctionVisitor(function);
        // A visited function's outputs become newly-available variables.
        VisitOutputVariables(function, TryEnqueueVaraible);
      }
      if (!variables_queue.empty()) {
        VT variable = variables_queue.front();
        variables_queue.pop();
        VariableVisitor(variable);
        VisitNextFunctions(variable, [&](FT function) {
          // Ready check: every input of `function` must already be queued.
          size_t num_unfinished_inputs = 0;
          VisitInputVariables(function, [&](VT in_variable) {
            num_unfinished_inputs +=
                (queued_variables.count(in_variable) > 0 ? 0 : 1);
          });
          if (num_unfinished_inputs == 0) {
            TryEnqueueFunction(function);
          }
        });
      }
    }
  }

  // tNext [Function] <- Variable
  F4VVisitor VisitNextFunctions;
  // tIn [Variable] <- Function
  V4FVisitor VisitInputVariables;
  // tOut [Variable] <- Function
  V4FVisitor VisitOutputVariables;
};
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(yifan): Add unittest here
#include "paddle/cinn/common/equation_graph_topo_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace adt {
namespace common {
using VT = int;
using FT = std::string;
/*
Graph ex:
1-> "1->10" -> 10
2-> "2->20" -> 20
*/
// Starting at variable 1, only the function of 1's component is reachable.
TEST(EquationGraphTopoWalker, simple1) {
  // Variable -> adjacent functions for the two-edge toy graph.
  auto NextFunctions = [](VT v, const std::function<void(FT)>& Visit) {
    switch (v) {
      case 1:
        Visit("1->10");
        break;
      case 2:
        Visit("2->20");
        break;
      default:
        break;
    }
  };
  // Function -> input variables.
  auto InputVars = [](FT f, const std::function<void(VT)>& Visit) {
    if (f == "1->10") {
      Visit(1);
    } else if (f == "2->20") {
      Visit(2);
    }
  };
  // Function -> output variables.
  auto OutputVars = [](FT f, const std::function<void(VT)>& Visit) {
    if (f == "1->10") {
      Visit(10);
    } else if (f == "2->20") {
      Visit(20);
    }
  };
  cinn::EquationGraphTopoWalker<VT, FT> walker(
      NextFunctions, InputVars, OutputVars);
  std::vector<FT> visited;
  walker.WalkFunction(1, [&](FT fn) { visited.push_back(fn); });
  const std::vector<FT> expected{"1->10"};
  EXPECT_TRUE((visited == expected));
}
/*
Graph ex:
1 -> "1->10, 1->11" -> 10
-> 11
2 -> "2->20" -> 20
3 -> "3->30, 3->31" -> 30
-> 31
*/
// Starting at variable 1, only its component's variables are visited, in
// topological order: the source first, then its function's outputs.
TEST(EquationGraphTopoWalker, simple2) {
  // Variable -> adjacent functions for the three-component graph above.
  auto F4V = [](VT variable, const std::function<void(FT)>& visitor) {
    if (variable == 1) {
      visitor("1->10, 1->11");
    } else if (variable == 2) {
      visitor("2->20");
    } else if (variable == 3) {
      visitor("3->30, 3->31");
    }
  };
  // Function -> input variables.
  auto InV4F = [](FT function, const std::function<void(VT)>& visitor) {
    if (function == "1->10, 1->11") {
      visitor(1);
    } else if (function == "2->20") {
      visitor(2);
    } else if (function == "3->30, 3->31") {
      visitor(3);
    }
  };
  // Function -> output variables.
  auto OutV4F = [](FT function, const std::function<void(VT)>& visitor) {
    if (function == "1->10, 1->11") {
      visitor(10);
      visitor(11);
    } else if (function == "2->20") {
      visitor(20);
    } else if (function == "3->30, 3->31") {
      visitor(30);
      visitor(31);
    }
  };
  cinn::EquationGraphTopoWalker<VT, FT> walker(F4V, InV4F, OutV4F);
  std::vector<VT> outputs;
  std::function<void(VT)> VariableVisitor = [&](VT variable) {
    outputs.push_back(variable);
  };
  walker.WalkVariable(1, VariableVisitor);
  // 1 is visited first, then the outputs 10 and 11 of "1->10, 1->11".
  std::vector<VT> expected{1, 10, 11};
  EXPECT_TRUE((outputs == expected));
}
} // namespace common
} // namespace adt
......@@ -18,10 +18,9 @@
#include <unordered_set>
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"
namespace cinn {
namespace common {
......@@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32));
Expr indice_prod = indices[i];
optim::CastSimplify(&indice_prod);
optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) {
indice_prod = RampRelatedMul(indice_prod, shape[j]);
}
......@@ -250,8 +249,8 @@ Expr or_all(const std::vector<Expr> &conds) {
}
void CheckTensorUniqueInExpr(Expr expr) {
auto tensor_uniq =
ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); });
auto tensor_uniq = ir::ir_utils::CollectIRNodes(
expr, [](const Expr *x) { return x->as_tensor(); });
absl::flat_hash_map<std::string, const ir::_Tensor_ *> tensor_names;
for (auto &t : tensor_uniq) {
auto *tp = t.as_tensor();
......@@ -270,9 +269,9 @@ void CheckBufferUniqueInExpr(Expr expr) {
// the buffers exists in tensor and lowered functions.
CheckTensorUniqueInExpr(expr);
auto tensors =
ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); });
auto funcs = ir::CollectIRNodes(
auto tensors = ir::ir_utils::CollectIRNodes(
expr, [](const Expr *x) { return x->as_tensor(); });
auto funcs = ir::ir_utils::CollectIRNodes(
expr, [](const Expr *x) { return x->as_lowered_func(); });
absl::flat_hash_map<std::string, const ir::_Buffer_ *> buffer_name;
......
......@@ -69,8 +69,8 @@
#define USE_FUSION_PASS(pass_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_fusion_pass_##pass_name, \
"USE_OP_ITSELF must be called in global namespace"); \
extern int TouchFusionPassRegistrar_##pass_name(); \
[[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \
TouchFusionPassRegistrar_##pass_name()
__use_cinn_fusion_pass_##pass_name, \
"USE_FUSION_PASS must be called in global namespace"); \
extern int TouchCinnFusionPassRegistrar_##pass_name(); \
[[maybe_unused]] static int __use_cinn_fusion_pass_##pass_name##_ = \
TouchCinnFusionPassRegistrar_##pass_name()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "paddle/cinn/common/topo_walker.h"
namespace cinn::common {
// Builds a predicate answering "is this node reachable from any node in
// [src_begin, src_end)?".  Reachability is computed eagerly with one walk;
// the returned closure then answers queries with a set lookup.
template <typename NodeT, typename IterT>
std::function<bool(NodeT)> MakeIsReachableFromSrcPredicator(
    const TopoWalker<NodeT>& walker, IterT src_begin, IterT src_end) {
  // The sources themselves count as reachable.
  auto reachable =
      std::make_shared<std::unordered_set<NodeT>>(src_begin, src_end);
  walker(src_begin, src_end,
         [reachable](NodeT node) { reachable->insert(node); });
  // The shared_ptr keeps the set alive for as long as the predicate lives.
  return [reachable](NodeT node) {
    return reachable->find(node) != reachable->end();
  };
}
} // namespace cinn::common
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "paddle/cinn/common/make_is_reachable_from_src_predicator.h"
#include "paddle/cinn/common/topo_walker.h"
namespace cinn::common {
// Builds a TopoWalker restricted to the subgraph spanned between the given
// sources and sinks: a node is traversed only if it is reachable from at
// least one source (forward walk) AND can reach at least one sink (walk on
// the reversed graph).
template <typename NodeT, typename IterT>
common::TopoWalker<NodeT> MakeSubgraphWalker(
    const common::TopoWalker<NodeT>& walker,
    IterT src_begin,
    IterT src_end,
    IterT sink_begin,
    IterT sink_end) {
  // Reversed view of the graph: prev/next visitor roles are swapped.
  common::TopoWalker<NodeT> reversed_walker(walker.VisitNextNodes,
                                            walker.VisitPrevNodes);
  auto ReachableToOneSrc =
      common::MakeIsReachableFromSrcPredicator<NodeT, IterT>(
          walker, src_begin, src_end);
  // "Can reach a sink" == "reachable from a sink in the reversed graph".
  auto ReachableToOneSink =
      common::MakeIsReachableFromSrcPredicator<NodeT, IterT>(
          reversed_walker, sink_begin, sink_end);
  // Both filtered visitors capture the predicates and walker by value so the
  // returned walker remains valid after this function returns.
  auto VisitPrevNodes = [ReachableToOneSrc, ReachableToOneSink, walker](
                            NodeT node,
                            const std::function<void(NodeT)>& Visitor) {
    walker.VisitPrevNodes(node, [&](NodeT in_node) {
      if (ReachableToOneSrc(in_node) && ReachableToOneSink(in_node)) {
        Visitor(in_node);
      }
    });
  };
  auto VisitNextNodes = [ReachableToOneSrc, ReachableToOneSink, walker](
                            NodeT node,
                            const std::function<void(NodeT)>& Visitor) {
    walker.VisitNextNodes(node, [&](NodeT out_node) {
      if (ReachableToOneSrc(out_node) && ReachableToOneSink(out_node)) {
        Visitor(out_node);
      }
    });
  };
  return common::TopoWalker<NodeT>(VisitPrevNodes, VisitNextNodes);
}
} // namespace cinn::common
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/common/nvgpu_dev_info.h"
namespace cinn {
namespace common {
std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxGridSize[0];
ret[1] = prop_.maxGridSize[1];
ret[2] = prop_.maxGridSize[2];
return ret;
}
std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxThreadsDim[0];
ret[1] = prop_.maxThreadsDim[1];
ret[2] = prop_.maxThreadsDim[2];
return ret;
}
// Number of streaming multiprocessors on the device.
int NVGPUDevInfo::GetMultiProcessorCount() const {
  return prop_.multiProcessorCount;
}
// Maximum resident threads per streaming multiprocessor.
int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
  return prop_.maxThreadsPerMultiProcessor;
}
// Maximum number of threads allowed in a single block.
int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
  return prop_.maxThreadsPerBlock;
}
} // namespace common
} // namespace cinn
#endif
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef CINN_WITH_CUDA
#include <array>
#include <ostream>
#include <string>
#include <vector>

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/dev_info_base.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/target.h"
namespace cinn {
namespace common {
// Device-info provider for NVIDIA GPUs.  Queries cudaDeviceProp once at
// construction and serves all accessors from the cached copy.
class NVGPUDevInfo : public DevInfoBase {
 public:
  // Fetches the CUDA device properties for the given device ordinal.
  // NOTE(review): CUDA_CALL (from backends/cuda_util.h) presumably fails
  // loudly on a CUDA API error — verify its behavior before relying on it.
  explicit NVGPUDevInfo(int device_num = 0) : DevInfoBase(device_num) {
    CUDA_CALL(cudaGetDeviceProperties(&prop_, device_num));
  }

  std::array<int, 3> GetMaxGridDims() const;
  std::array<int, 3> GetMaxBlockDims() const;
  int GetMultiProcessorCount() const;
  int GetMaxThreadsPerMultiProcessor() const;
  int GetMaxThreadsPerBlock() const;

 private:
  cudaDeviceProp prop_;  // cached device properties, filled in the ctor
};
} // namespace common
} // namespace cinn
#endif
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef CINN_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <driver_types.h>
#endif
......@@ -20,12 +21,20 @@
#include <sstream>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace cinn {
namespace common {
// Out-of-line constructor: plain member-wise initialization (the header
// keeps only the declaration).
Target::Target(OS o,
               Arch a,
               Bit b,
               const std::vector<Feature> &features,
               const std::vector<Lib> &libs)
    : os(o), arch(a), bits(b), features(features), libs(libs) {}
bool Target::operator==(const Target &other) const {
return os == other.os && //
arch == other.arch && //
......
......@@ -14,6 +14,7 @@
#pragma once
#include <array>
#include <ostream>
#include <string>
#include <vector>
......@@ -71,8 +72,7 @@ struct Target {
Arch a = Arch::Unk,
Bit b = Bit::Unk,
const std::vector<Feature>& features = {},
const std::vector<Lib>& libs = {})
: os(o), arch(a), bits(b), features(features), libs(libs) {}
const std::vector<Lib>& libs = {});
bool defined() const {
return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk;
......
......@@ -26,16 +26,17 @@ namespace common {
template <typename NodeType>
class TopoWalker final {
public:
TopoWalker(const TopoWalker&) = delete;
TopoWalker(TopoWalker&&) = delete;
TopoWalker(const TopoWalker&) = default;
TopoWalker(TopoWalker&&) = default;
using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;
TopoWalker(const NodesVisitorType& VisitPrevNodes,
const NodesVisitorType& VisitNextNodes)
: VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {}
TopoWalker(const NodesVisitorType& VisitPrevNodesValue,
const NodesVisitorType& VisitNextNodesValue)
: VisitPrevNodes(VisitPrevNodesValue),
VisitNextNodes(VisitNextNodesValue) {}
void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
......@@ -61,9 +62,9 @@ class TopoWalker final {
NodeType node = node_queue.front();
node_queue.pop();
NodeHandler(node);
VisitNextNodes_(node, [&](NodeType node) {
VisitNextNodes(node, [&](NodeType node) {
size_t num_unfinished_inputs = 0;
VisitPrevNodes_(node, [&](NodeType in_node) {
VisitPrevNodes(node, [&](NodeType in_node) {
num_unfinished_inputs += (queued_nodes.count(in_node) > 0 ? 0 : 1);
});
if (num_unfinished_inputs == 0) {
......@@ -73,9 +74,8 @@ class TopoWalker final {
}
}
private:
NodesVisitorType VisitPrevNodes_;
NodesVisitorType VisitNextNodes_;
NodesVisitorType VisitPrevNodes;
NodesVisitorType VisitNextNodes;
};
} // namespace common
......
......@@ -72,16 +72,18 @@ std::shared_ptr<ComputationContext> CompileProgram(
}
ctx->scope = hlir::framework::BuildScope(target, ctx->graph, scope);
ctx->graph_compiler.reset(
new hlir::framework::GraphCompiler(target, ctx->scope, ctx->graph));
std::unordered_set<std::string> fetch_var_ids;
for (auto &out : outputs) {
fetch_var_ids.insert(out->id);
}
ctx->program = ctx->graph_compiler->Build(options, std::move(fetch_var_ids))
.runtime_program;
ctx->compile_options.graph = ctx->graph;
ctx->compile_options.scope = ctx->scope;
ctx->compile_options.fetch_var_ids = fetch_var_ids;
ctx->graph_compiler.reset(
new hlir::framework::GraphCompiler(ctx->compile_options));
ctx->program = ctx->graph_compiler->Build();
if (ctx->compile_options.do_prerun) {
ctx->program->PreRun();
}
......
......@@ -27,8 +27,7 @@ struct ComputationContext;
class CinnComputation {
public:
struct CompileOptions
: public hlir::framework::GraphCompiler::CompileOptions {
struct CompileOptions : public hlir::framework::CompilationContext {
bool use_decomposer = false;
bool do_prerun = true;
bool use_default_passes = true;
......
......@@ -23,7 +23,7 @@
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......
......@@ -86,7 +86,8 @@ TEST(Decomposer, softmax_decomposer) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
std::vector<float> x(n * c * h * w);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment