Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -200,7 +200,8 @@ TEST(Decomposer, BatchNormTrain) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
// set input
......@@ -399,7 +400,8 @@ TEST(Decomposer, BatchNormGrad) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
// set input
......
......@@ -27,6 +27,7 @@
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -208,7 +209,8 @@ void RunAndCheckShape(NetBuilder* builder,
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
std::vector<std::vector<T>> input_vecs_internal;
......
......@@ -38,7 +38,8 @@ TEST(Decomposer, top_k_decomposer) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
std::vector<float> x(10 * 5);
......
......@@ -19,12 +19,13 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(enable_auto_tuner);
PD_DECLARE_bool(enable_auto_tuner);
namespace cinn::frontend {
......@@ -120,10 +121,8 @@ void Interpreter::Impl::Build(const Target& target,
graph->attrs["model_name"] = std::make_shared<absl::any>(model_name);
scope_ = hlir::framework::BuildScope(target, graph, scope_);
graph_compiler_.reset(
new hlir::framework::GraphCompiler(target, scope_, graph));
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
hlir::framework::CompilationContext context(graph, scope_, target);
context.with_instantiate_variables = true;
if (FLAGS_enable_auto_tuner) {
VLOG(4) << "Compile with auto-tune";
auto_schedule::AutoTuner auto_tuner(target, graph.get());
......@@ -131,10 +130,10 @@ void Interpreter::Impl::Build(const Target& target,
graph_compiler_.get());
auto_schedule::TuningOptions tuning_options;
auto_schedule::TuningResult tuning_result = auto_tuner.Tune(tuning_options);
options.Apply(tuning_result);
context.ApplyTuningResult(tuning_result);
}
runtime_program_ =
graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program;
graph_compiler_ = std::make_unique<hlir::framework::GraphCompiler>(context);
runtime_program_ = graph_compiler_->Build();
runtime_program_->PreRun();
}
......@@ -150,4 +149,4 @@ Interpreter::Interpreter(
} // namespace cinn::frontend
cinn::frontend::Interpreter::~Interpreter() {}
cinn::frontend::Interpreter::~Interpreter() = default;
......@@ -18,7 +18,7 @@
#include "paddle/cinn/runtime/use_extern_funcs.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn::frontend {
......
......@@ -26,6 +26,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/utils/data_util.h"
......@@ -99,7 +100,8 @@ TEST(net_build, program_execute_multi_elementwise_add) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -139,7 +141,8 @@ TEST(net_build, program_execute_fc) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id()));
......@@ -183,7 +186,8 @@ TEST(net_build, program_execute_multi_elementwise_add_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -224,7 +228,8 @@ TEST(net_build, program_execute_fc_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id()));
......@@ -285,7 +290,8 @@ TEST(net_build, program_execute_pool2d) {
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -318,7 +324,8 @@ TEST(net_build, program_execute_reverse) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -349,7 +356,8 @@ TEST(net_build, program_execute_gather) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
......@@ -409,7 +417,8 @@ TEST(net_build, program_execute_gather_nd) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
......@@ -469,7 +478,8 @@ TEST(net_build, program_execute_cast) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -523,7 +533,8 @@ TEST(net_build, program_execute_squeeze_case0) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -582,7 +593,8 @@ TEST(net_build, program_execute_squeeze_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -641,7 +653,8 @@ TEST(net_build, program_execute_squeeze_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -699,7 +712,8 @@ TEST(net_build, program_execute_squeeze_case3) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -757,7 +771,8 @@ TEST(net_build, program_execute_squeeze_case4) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -813,7 +828,8 @@ TEST(net_build, program_execute_argsort) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -874,7 +890,8 @@ TEST(net_build, program_execute_sort) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -934,7 +951,8 @@ TEST(net_build, program_execute_arange_float) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id));
......@@ -975,7 +993,8 @@ TEST(net_build, program_execute_arange_int) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id));
......@@ -1018,7 +1037,8 @@ TEST(net_build, program_argmax_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1092,7 +1112,8 @@ TEST(net_build, program_argmax_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1170,7 +1191,8 @@ TEST(net_build, program_argmin_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1247,7 +1269,8 @@ TEST(net_build, program_argmin_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1324,7 +1347,8 @@ TEST(net_build, program_execute_repeat_axis_0) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1379,7 +1403,8 @@ TEST(net_build, program_execute_repeat_axis_1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1440,7 +1465,8 @@ TEST(net_build, program_execute_one_hot) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......
......@@ -67,11 +67,23 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
"but here cannot found! Please check.";
}
cinn::utils::ShapeType input_shape(ctx.GetVar(x_names.front())->shape);
auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0);
axis = axis >= 0 ? axis : axis + input_shape.size() + 1;
cinn::utils::ShapeType output_shape(input_shape);
output_shape.insert(output_shape.begin() + axis, 1);
std::vector<Variable> xs;
for (const auto& name : x_names) {
xs.emplace_back(ctx.GetVar(name));
auto x = ctx.GetVar(name);
CHECK(x->shape == input_shape)
<< "All input shape of [stack] should be the same, be the input "
<< x->id << "'s shape [" << cinn::utils::Join(x->shape, ", ")
<< "] not equal to "
<< "the first input " << ctx.GetVar(x_names.front())->id << "'s shape ["
<< cinn::utils::Join(input_shape, ", ") << "]";
xs.emplace_back(ctx.Builder()->Reshape(x, output_shape));
}
auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
......@@ -83,39 +95,10 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
<< "] not equal to the first input " << xs.front()->id << "'s dtype ["
<< xs.front()->type << "]";
err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
return x->shape != xs.front()->shape;
});
CHECK(err_x == xs.end())
<< "All input shape of [stack] should be the same, be the input "
<< (*err_x)->id << "'s shape ["
<< cinn::utils::Join((*err_x)->shape, ", ") << "] not equal to "
<< "the first input " << xs.front()->id << "'s shape ["
<< cinn::utils::Join(xs.front()->shape, ", ") << "]";
auto concat_out = ctx.Builder()->Concat(xs, axis);
int rank = concat_out->shape.size();
axis = axis >= 0 ? axis : axis + rank;
CHECK(axis >= 0 && axis < rank)
<< "The axis of stack should >=0 and <rank(x)! Please check.";
// N * [A, B] with axis=0 --> [N, A, B]; N * [A, B] with axis=1 --> [A, N, B];
cinn::utils::ShapeType new_shape;
for (int i = 0; i < rank; ++i) {
auto dim = concat_out->shape[i];
if (i != axis) {
new_shape.emplace_back(dim);
} else {
new_shape.emplace_back(xs.size());
// the shape same ensure `dim % xs.size() == 0`
new_shape.emplace_back(dim / xs.size());
}
}
auto out = ctx.Builder()->Reshape(concat_out, new_shape);
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
ctx.AddVar(out_name, concat_out);
ctx.AddVarModelToProgram(out_name, concat_out->id);
}
void SplitOpMapper(const paddle::cpp::OpDesc& op_desc,
......
......@@ -29,15 +29,16 @@
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(cinn_use_fill_constant_folding);
DECLARE_bool(cinn_use_op_fusion);
DECLARE_bool(cinn_use_common_subexpression_elimination);
DECLARE_string(cinn_check_fusion_accuracy_pass);
DECLARE_bool(cinn_use_custom_call);
DECLARE_bool(use_reduce_split_pass);
DECLARE_bool(cinn_use_dense_merge_pass);
DECLARE_string(cinn_custom_call_deny_ops);
DECLARE_bool(general_fusion_merge_pass);
PD_DECLARE_bool(cinn_use_fill_constant_folding);
PD_DECLARE_bool(cinn_use_op_fusion);
PD_DECLARE_bool(cinn_use_common_subexpression_elimination);
PD_DECLARE_string(cinn_check_fusion_accuracy_pass);
PD_DECLARE_bool(cinn_use_custom_call);
PD_DECLARE_bool(use_reduce_split_pass);
PD_DECLARE_bool(cinn_use_dense_merge_pass);
PD_DECLARE_string(cinn_custom_call_deny_ops);
PD_DECLARE_bool(general_fusion_merge_pass);
PD_DECLARE_bool(cinn_use_cutlass);
namespace cinn {
namespace frontend {
......@@ -58,6 +59,7 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
return FLAGS_cinn_custom_call_deny_ops.find(op) != std::string::npos;
};
bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call &&
!FLAGS_cinn_use_cutlass &&
!can_find_custom_call_deny_op("matmul") &&
!can_find_custom_call_deny_op("cublas_gemm") &&
!can_find_custom_call_deny_op("cublas_matmul");
......
......@@ -14,10 +14,10 @@
#include "paddle/cinn/frontend/paddle/model_parser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/utils/flags.h"
DEFINE_string(model_dir, "<NOTEXIST>", "model directory path");
PD_DEFINE_string(model_dir, "<NOTEXIST>", "model directory path");
namespace cinn::frontend::paddle {
......
......@@ -27,7 +27,7 @@
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/cinn/hlir/op/use_ops.h"
DECLARE_double(cinn_infer_model_version);
PD_DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace frontend {
......
......@@ -20,7 +20,7 @@
#include "paddle/cinn/frontend/decomposer/test_helper.h"
#include "paddle/cinn/runtime/use_extern_funcs.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -69,7 +69,8 @@ void RunProgram(const Target& target, Program* prog) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
for (size_t i = 0; i < input_names.size(); ++i) {
......
......@@ -21,7 +21,7 @@
#include "paddle/cinn/frontend/paddle/pb/program_desc.h"
#include "paddle/cinn/hlir/framework/node.h"
DECLARE_double(cinn_infer_model_version);
PD_DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace frontend {
......
......@@ -73,7 +73,8 @@ TEST(DecomposePass, basic) {
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -61,7 +61,8 @@ std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
for (auto& data : input_data) {
scope->Var<hlir::framework::Tensor>(data.first);
......
......@@ -40,7 +40,8 @@ std::vector<float> RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
......
......@@ -14,7 +14,6 @@
#pragma once
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <algorithm>
......@@ -36,13 +35,15 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/utils/flags.h"
DECLARE_bool(cinn_use_op_fusion);
PD_DECLARE_bool(cinn_use_op_fusion);
namespace cinn {
namespace frontend {
......@@ -79,14 +80,12 @@ inline void RunGraph(std::shared_ptr<hlir::framework::Graph> graph,
hlir::framework::ApplyPasses(graph.get(), graph_passes);
VLOG(3) << "Graph Viz:\n" << graph->Visualize();
BuildScope(target, graph, scope);
hlir::framework::GraphCompiler::CompileOptions options;
options.attached_code = "";
options.with_instantiate_variables = true;
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build(options,
std::unordered_set<std::string>(
output_ids.begin(), output_ids.end()))
.runtime_program;
hlir::framework::CompilationContext context(graph, scope, target);
context.attached_source_code = "";
context.with_instantiate_variables = true;
context.fetch_var_ids.insert(output_ids.begin(), output_ids.end());
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
......@@ -25,6 +25,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -40,7 +41,8 @@ void RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -117,11 +118,10 @@ class PassTest {
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = hlir::framework::BuildScope(target_, graph);
hlir::framework::GraphCompiler gc(target_, scope, graph);
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
auto result = gc.Build(options, std::move(fetch_var_ids));
auto runtime_program = std::move(result.runtime_program);
hlir::framework::CompilationContext context(graph, scope, target_);
context.with_instantiate_variables = true;
hlir::framework::GraphCompiler gc(context);
auto runtime_program = std::move(gc.Build());
for (auto& name : input_names) {
SetInputTensor(name, scope);
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -68,7 +69,8 @@ std::vector<std::vector<float>> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape", "OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
......
......@@ -24,6 +24,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -38,7 +39,8 @@ void RunWithProgram(const Program& program,
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment