Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
...@@ -200,7 +200,8 @@ TEST(Decomposer, BatchNormTrain) { ...@@ -200,7 +200,8 @@ TEST(Decomposer, BatchNormTrain) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build(); auto run_program = gc.Build();
// set input // set input
...@@ -399,7 +400,8 @@ TEST(Decomposer, BatchNormGrad) { ...@@ -399,7 +400,8 @@ TEST(Decomposer, BatchNormGrad) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build(); auto run_program = gc.Build();
// set input // set input
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
...@@ -208,7 +209,8 @@ void RunAndCheckShape(NetBuilder* builder, ...@@ -208,7 +209,8 @@ void RunAndCheckShape(NetBuilder* builder,
auto graph = std::make_shared<hlir::framework::Graph>(prog, target); auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
std::vector<std::vector<T>> input_vecs_internal; std::vector<std::vector<T>> input_vecs_internal;
......
...@@ -38,7 +38,8 @@ TEST(Decomposer, top_k_decomposer) { ...@@ -38,7 +38,8 @@ TEST(Decomposer, top_k_decomposer) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build(); auto run_program = gc.Build();
std::vector<float> x(10 * 5); std::vector<float> x(10 * 5);
......
...@@ -19,12 +19,13 @@ ...@@ -19,12 +19,13 @@
#include "paddle/cinn/frontend/optimize.h" #include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
DECLARE_bool(enable_auto_tuner); PD_DECLARE_bool(enable_auto_tuner);
namespace cinn::frontend { namespace cinn::frontend {
...@@ -120,10 +121,8 @@ void Interpreter::Impl::Build(const Target& target, ...@@ -120,10 +121,8 @@ void Interpreter::Impl::Build(const Target& target,
graph->attrs["model_name"] = std::make_shared<absl::any>(model_name); graph->attrs["model_name"] = std::make_shared<absl::any>(model_name);
scope_ = hlir::framework::BuildScope(target, graph, scope_); scope_ = hlir::framework::BuildScope(target, graph, scope_);
graph_compiler_.reset( hlir::framework::CompilationContext context(graph, scope_, target);
new hlir::framework::GraphCompiler(target, scope_, graph)); context.with_instantiate_variables = true;
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
if (FLAGS_enable_auto_tuner) { if (FLAGS_enable_auto_tuner) {
VLOG(4) << "Compile with auto-tune"; VLOG(4) << "Compile with auto-tune";
auto_schedule::AutoTuner auto_tuner(target, graph.get()); auto_schedule::AutoTuner auto_tuner(target, graph.get());
...@@ -131,10 +130,10 @@ void Interpreter::Impl::Build(const Target& target, ...@@ -131,10 +130,10 @@ void Interpreter::Impl::Build(const Target& target,
graph_compiler_.get()); graph_compiler_.get());
auto_schedule::TuningOptions tuning_options; auto_schedule::TuningOptions tuning_options;
auto_schedule::TuningResult tuning_result = auto_tuner.Tune(tuning_options); auto_schedule::TuningResult tuning_result = auto_tuner.Tune(tuning_options);
options.Apply(tuning_result); context.ApplyTuningResult(tuning_result);
} }
runtime_program_ = graph_compiler_ = std::make_unique<hlir::framework::GraphCompiler>(context);
graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program; runtime_program_ = graph_compiler_->Build();
runtime_program_->PreRun(); runtime_program_->PreRun();
} }
...@@ -150,4 +149,4 @@ Interpreter::Interpreter( ...@@ -150,4 +149,4 @@ Interpreter::Interpreter(
} // namespace cinn::frontend } // namespace cinn::frontend
cinn::frontend::Interpreter::~Interpreter() {} cinn::frontend::Interpreter::~Interpreter() = default;
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/cinn/runtime/use_extern_funcs.h" #include "paddle/cinn/runtime/use_extern_funcs.h"
DEFINE_string(model_dir, "", ""); PD_DEFINE_string(model_dir, "", "");
namespace cinn::frontend { namespace cinn::frontend {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/utils/data_util.h" #include "paddle/cinn/utils/data_util.h"
...@@ -99,7 +100,8 @@ TEST(net_build, program_execute_multi_elementwise_add) { ...@@ -99,7 +100,8 @@ TEST(net_build, program_execute_multi_elementwise_add) {
LOG(INFO) << "graph:\n" << graph->Visualize(); LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A"); scope->Var<hlir::framework::Tensor>("A");
...@@ -139,7 +141,8 @@ TEST(net_build, program_execute_fc) { ...@@ -139,7 +141,8 @@ TEST(net_build, program_execute_fc) {
LOG(INFO) << "graph:\n" << graph->Visualize(); LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id())); scope->Var<hlir::framework::Tensor>(std::string(a.id()));
...@@ -183,7 +186,8 @@ TEST(net_build, program_execute_multi_elementwise_add_bf16) { ...@@ -183,7 +186,8 @@ TEST(net_build, program_execute_multi_elementwise_add_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize(); LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A"); scope->Var<hlir::framework::Tensor>("A");
...@@ -224,7 +228,8 @@ TEST(net_build, program_execute_fc_bf16) { ...@@ -224,7 +228,8 @@ TEST(net_build, program_execute_fc_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize(); LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id())); scope->Var<hlir::framework::Tensor>(std::string(a.id()));
...@@ -285,7 +290,8 @@ TEST(net_build, program_execute_pool2d) { ...@@ -285,7 +290,8 @@ TEST(net_build, program_execute_pool2d) {
std::unordered_set<std::string> fetch_ids; std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -318,7 +324,8 @@ TEST(net_build, program_execute_reverse) { ...@@ -318,7 +324,8 @@ TEST(net_build, program_execute_reverse) {
LOG(INFO) << "graph:\n" << graph->Visualize(); LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -349,7 +356,8 @@ TEST(net_build, program_execute_gather) { ...@@ -349,7 +356,8 @@ TEST(net_build, program_execute_gather) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id())); scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
...@@ -409,7 +417,8 @@ TEST(net_build, program_execute_gather_nd) { ...@@ -409,7 +417,8 @@ TEST(net_build, program_execute_gather_nd) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id())); scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
...@@ -469,7 +478,8 @@ TEST(net_build, program_execute_cast) { ...@@ -469,7 +478,8 @@ TEST(net_build, program_execute_cast) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -523,7 +533,8 @@ TEST(net_build, program_execute_squeeze_case0) { ...@@ -523,7 +533,8 @@ TEST(net_build, program_execute_squeeze_case0) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -582,7 +593,8 @@ TEST(net_build, program_execute_squeeze_case1) { ...@@ -582,7 +593,8 @@ TEST(net_build, program_execute_squeeze_case1) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -641,7 +653,8 @@ TEST(net_build, program_execute_squeeze_case2) { ...@@ -641,7 +653,8 @@ TEST(net_build, program_execute_squeeze_case2) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -699,7 +712,8 @@ TEST(net_build, program_execute_squeeze_case3) { ...@@ -699,7 +712,8 @@ TEST(net_build, program_execute_squeeze_case3) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -757,7 +771,8 @@ TEST(net_build, program_execute_squeeze_case4) { ...@@ -757,7 +771,8 @@ TEST(net_build, program_execute_squeeze_case4) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -813,7 +828,8 @@ TEST(net_build, program_execute_argsort) { ...@@ -813,7 +828,8 @@ TEST(net_build, program_execute_argsort) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -874,7 +890,8 @@ TEST(net_build, program_execute_sort) { ...@@ -874,7 +890,8 @@ TEST(net_build, program_execute_sort) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -934,7 +951,8 @@ TEST(net_build, program_execute_arange_float) { ...@@ -934,7 +951,8 @@ TEST(net_build, program_execute_arange_float) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id)); scope->Var<hlir::framework::Tensor>(std::string(out->id));
...@@ -975,7 +993,8 @@ TEST(net_build, program_execute_arange_int) { ...@@ -975,7 +993,8 @@ TEST(net_build, program_execute_arange_int) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id)); scope->Var<hlir::framework::Tensor>(std::string(out->id));
...@@ -1018,7 +1037,8 @@ TEST(net_build, program_argmax_case1) { ...@@ -1018,7 +1037,8 @@ TEST(net_build, program_argmax_case1) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1092,7 +1112,8 @@ TEST(net_build, program_argmax_case2) { ...@@ -1092,7 +1112,8 @@ TEST(net_build, program_argmax_case2) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1170,7 +1191,8 @@ TEST(net_build, program_argmin_case1) { ...@@ -1170,7 +1191,8 @@ TEST(net_build, program_argmin_case1) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1247,7 +1269,8 @@ TEST(net_build, program_argmin_case2) { ...@@ -1247,7 +1269,8 @@ TEST(net_build, program_argmin_case2) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1324,7 +1347,8 @@ TEST(net_build, program_execute_repeat_axis_0) { ...@@ -1324,7 +1347,8 @@ TEST(net_build, program_execute_repeat_axis_0) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1379,7 +1403,8 @@ TEST(net_build, program_execute_repeat_axis_1) { ...@@ -1379,7 +1403,8 @@ TEST(net_build, program_execute_repeat_axis_1) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
...@@ -1440,7 +1465,8 @@ TEST(net_build, program_execute_one_hot) { ...@@ -1440,7 +1465,8 @@ TEST(net_build, program_execute_one_hot) {
auto graph = Optimize(&program, fetch_ids, target); auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id())); scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......
...@@ -67,11 +67,23 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, ...@@ -67,11 +67,23 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
"but here cannot found! Please check."; "but here cannot found! Please check.";
} }
cinn::utils::ShapeType input_shape(ctx.GetVar(x_names.front())->shape);
auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0); auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0);
axis = axis >= 0 ? axis : axis + input_shape.size() + 1;
cinn::utils::ShapeType output_shape(input_shape);
output_shape.insert(output_shape.begin() + axis, 1);
std::vector<Variable> xs; std::vector<Variable> xs;
for (const auto& name : x_names) { for (const auto& name : x_names) {
xs.emplace_back(ctx.GetVar(name)); auto x = ctx.GetVar(name);
CHECK(x->shape == input_shape)
<< "All input shape of [stack] should be the same, be the input "
<< x->id << "'s shape [" << cinn::utils::Join(x->shape, ", ")
<< "] not equal to "
<< "the first input " << ctx.GetVar(x_names.front())->id << "'s shape ["
<< cinn::utils::Join(input_shape, ", ") << "]";
xs.emplace_back(ctx.Builder()->Reshape(x, output_shape));
} }
auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
...@@ -83,39 +95,10 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, ...@@ -83,39 +95,10 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
<< "] not equal to the first input " << xs.front()->id << "'s dtype [" << "] not equal to the first input " << xs.front()->id << "'s dtype ["
<< xs.front()->type << "]"; << xs.front()->type << "]";
err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
return x->shape != xs.front()->shape;
});
CHECK(err_x == xs.end())
<< "All input shape of [stack] should be the same, be the input "
<< (*err_x)->id << "'s shape ["
<< cinn::utils::Join((*err_x)->shape, ", ") << "] not equal to "
<< "the first input " << xs.front()->id << "'s shape ["
<< cinn::utils::Join(xs.front()->shape, ", ") << "]";
auto concat_out = ctx.Builder()->Concat(xs, axis); auto concat_out = ctx.Builder()->Concat(xs, axis);
int rank = concat_out->shape.size(); ctx.AddVar(out_name, concat_out);
axis = axis >= 0 ? axis : axis + rank; ctx.AddVarModelToProgram(out_name, concat_out->id);
CHECK(axis >= 0 && axis < rank)
<< "The axis of stack should >=0 and <rank(x)! Please check.";
// N * [A, B] with axis=0 --> [N, A, B]; N * [A, B] with axis=1 --> [A, N, B];
cinn::utils::ShapeType new_shape;
for (int i = 0; i < rank; ++i) {
auto dim = concat_out->shape[i];
if (i != axis) {
new_shape.emplace_back(dim);
} else {
new_shape.emplace_back(xs.size());
// the shape same ensure `dim % xs.size() == 0`
new_shape.emplace_back(dim / xs.size());
}
}
auto out = ctx.Builder()->Reshape(concat_out, new_shape);
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
} }
void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, void SplitOpMapper(const paddle::cpp::OpDesc& op_desc,
......
...@@ -29,15 +29,16 @@ ...@@ -29,15 +29,16 @@
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
DECLARE_bool(cinn_use_fill_constant_folding); PD_DECLARE_bool(cinn_use_fill_constant_folding);
DECLARE_bool(cinn_use_op_fusion); PD_DECLARE_bool(cinn_use_op_fusion);
DECLARE_bool(cinn_use_common_subexpression_elimination); PD_DECLARE_bool(cinn_use_common_subexpression_elimination);
DECLARE_string(cinn_check_fusion_accuracy_pass); PD_DECLARE_string(cinn_check_fusion_accuracy_pass);
DECLARE_bool(cinn_use_custom_call); PD_DECLARE_bool(cinn_use_custom_call);
DECLARE_bool(use_reduce_split_pass); PD_DECLARE_bool(use_reduce_split_pass);
DECLARE_bool(cinn_use_dense_merge_pass); PD_DECLARE_bool(cinn_use_dense_merge_pass);
DECLARE_string(cinn_custom_call_deny_ops); PD_DECLARE_string(cinn_custom_call_deny_ops);
DECLARE_bool(general_fusion_merge_pass); PD_DECLARE_bool(general_fusion_merge_pass);
PD_DECLARE_bool(cinn_use_cutlass);
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
...@@ -58,6 +59,7 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { ...@@ -58,6 +59,7 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
return FLAGS_cinn_custom_call_deny_ops.find(op) != std::string::npos; return FLAGS_cinn_custom_call_deny_ops.find(op) != std::string::npos;
}; };
bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call && bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call &&
!FLAGS_cinn_use_cutlass &&
!can_find_custom_call_deny_op("matmul") && !can_find_custom_call_deny_op("matmul") &&
!can_find_custom_call_deny_op("cublas_gemm") && !can_find_custom_call_deny_op("cublas_gemm") &&
!can_find_custom_call_deny_op("cublas_matmul"); !can_find_custom_call_deny_op("cublas_matmul");
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
#include "paddle/cinn/frontend/paddle/model_parser.h" #include "paddle/cinn/frontend/paddle/model_parser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/utils/flags.h"
DEFINE_string(model_dir, "<NOTEXIST>", "model directory path"); PD_DEFINE_string(model_dir, "<NOTEXIST>", "model directory path");
namespace cinn::frontend::paddle { namespace cinn::frontend::paddle {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include "paddle/cinn/frontend/var_type_utils.h" #include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
DECLARE_double(cinn_infer_model_version); PD_DECLARE_double(cinn_infer_model_version);
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/cinn/frontend/decomposer/test_helper.h" #include "paddle/cinn/frontend/decomposer/test_helper.h"
#include "paddle/cinn/runtime/use_extern_funcs.h" #include "paddle/cinn/runtime/use_extern_funcs.h"
DEFINE_string(model_dir, "", ""); PD_DEFINE_string(model_dir, "", "");
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
...@@ -69,7 +69,8 @@ void RunProgram(const Target& target, Program* prog) { ...@@ -69,7 +69,8 @@ void RunProgram(const Target& target, Program* prog) {
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/cinn/frontend/paddle/pb/program_desc.h" #include "paddle/cinn/frontend/paddle/pb/program_desc.h"
#include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/node.h"
DECLARE_double(cinn_infer_model_version); PD_DECLARE_double(cinn_infer_model_version);
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
......
...@@ -73,7 +73,8 @@ TEST(DecomposePass, basic) { ...@@ -73,7 +73,8 @@ TEST(DecomposePass, basic) {
auto graph = std::make_shared<hlir::framework::Graph>(prog, target); auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A"); scope->Var<hlir::framework::Tensor>("A");
......
...@@ -61,7 +61,8 @@ std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram( ...@@ -61,7 +61,8 @@ std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape"}); hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize(); VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
for (auto& data : input_data) { for (auto& data : input_data) {
scope->Var<hlir::framework::Tensor>(data.first); scope->Var<hlir::framework::Tensor>(data.first);
......
...@@ -40,7 +40,8 @@ std::vector<float> RunWithProgram(const Program& program, ...@@ -40,7 +40,8 @@ std::vector<float> RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"}); hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize(); VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
runtime_program->Execute(); runtime_program->Execute();
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#pragma once #pragma once
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
...@@ -36,13 +35,15 @@ ...@@ -36,13 +35,15 @@
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h" #include "paddle/cinn/utils/data_util.h"
#include "paddle/utils/flags.h"
DECLARE_bool(cinn_use_op_fusion); PD_DECLARE_bool(cinn_use_op_fusion);
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
...@@ -79,14 +80,12 @@ inline void RunGraph(std::shared_ptr<hlir::framework::Graph> graph, ...@@ -79,14 +80,12 @@ inline void RunGraph(std::shared_ptr<hlir::framework::Graph> graph,
hlir::framework::ApplyPasses(graph.get(), graph_passes); hlir::framework::ApplyPasses(graph.get(), graph_passes);
VLOG(3) << "Graph Viz:\n" << graph->Visualize(); VLOG(3) << "Graph Viz:\n" << graph->Visualize();
BuildScope(target, graph, scope); BuildScope(target, graph, scope);
hlir::framework::GraphCompiler::CompileOptions options; hlir::framework::CompilationContext context(graph, scope, target);
options.attached_code = ""; context.attached_source_code = "";
options.with_instantiate_variables = true; context.with_instantiate_variables = true;
hlir::framework::GraphCompiler gc(target, scope, graph); context.fetch_var_ids.insert(output_ids.begin(), output_ids.end());
auto runtime_program = gc.Build(options, hlir::framework::GraphCompiler gc(context);
std::unordered_set<std::string>( auto runtime_program = gc.Build();
output_ids.begin(), output_ids.end()))
.runtime_program;
runtime_program->Execute(); runtime_program->Execute();
} }
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
...@@ -40,7 +41,8 @@ void RunWithProgram(const Program& program, ...@@ -40,7 +41,8 @@ void RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"}); hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize(); VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
runtime_program->Execute(); runtime_program->Execute();
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/pass/use_program_pass.h" #include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
...@@ -117,11 +118,10 @@ class PassTest { ...@@ -117,11 +118,10 @@ class PassTest {
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = hlir::framework::BuildScope(target_, graph); auto scope = hlir::framework::BuildScope(target_, graph);
hlir::framework::GraphCompiler gc(target_, scope, graph); hlir::framework::CompilationContext context(graph, scope, target_);
hlir::framework::GraphCompiler::CompileOptions options; context.with_instantiate_variables = true;
options.with_instantiate_variables = true; hlir::framework::GraphCompiler gc(context);
auto result = gc.Build(options, std::move(fetch_var_ids)); auto runtime_program = std::move(gc.Build());
auto runtime_program = std::move(result.runtime_program);
for (auto& name : input_names) { for (auto& name : input_names) {
SetInputTensor(name, scope); SetInputTensor(name, scope);
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
...@@ -68,7 +69,8 @@ std::vector<std::vector<float>> RunWithProgram( ...@@ -68,7 +69,8 @@ std::vector<std::vector<float>> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape", "OpFusionPass"}); hlir::framework::ApplyPasses(graph.get(), {"InferShape", "OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize(); VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
runtime_program->Execute(); runtime_program->Execute();
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
...@@ -38,7 +39,8 @@ void RunWithProgram(const Program& program, ...@@ -38,7 +39,8 @@ void RunWithProgram(const Program& program,
auto graph = std::make_shared<hlir::framework::Graph>(program, target); auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass"}); hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize(); VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph); hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
runtime_program->Execute(); runtime_program->Execute();
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment