Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -23,7 +23,7 @@
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/hlir/pe/schedule.h"
#include "paddle/cinn/hlir/pe/transform.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
#ifdef CINN_WITH_CUDNN
......
......@@ -17,6 +17,7 @@
#include <iostream>
#include "absl/types/optional.h"
#include "paddle/cinn/adt/op_equation_context.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
......@@ -107,6 +108,13 @@ std::vector<Type> InferDtypeForElementwise(
return res;
}
void GenerateEquationsForElementwise(
    cinn::adt::config::OpEquationContext *ctx) {
  // Elementwise ops index every tensor with the same iterators:
  // out[i0, i1, ...] = f(in[i0, i1, ...]). Record that by equating the
  // first input's iterator tuple with the first output's.
  //
  // Idiom fix: use !empty() instead of size() != 0 for the emptiness check.
  CHECK(!ctx->GetInTensorsRanks().empty())
      << "The inputs is empty! Please check again.";
  ctx->Equal(ctx->GetInIteratorTuple(0), ctx->GetOutIteratorTuple(0));
}
std::vector<Type> InferDtypeForElementwiseBool(
const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
CHECK(!inputs_type.empty())
......@@ -157,23 +165,31 @@ std::shared_ptr<OpStrategy> StrategyForScale(
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();
if (bias_after_scale) {
out = Compute(
A->shape,
[=](const std::vector<Expr> &indice) {
return ir::Cast::Make(A->type(), Expr(scale)) * A(indice) +
ir::Cast::Make(A->type(), Expr(bias));
},
tensor_name);
} else {
out = Compute(
A->shape,
[=](const std::vector<Expr> &indice) {
return ir::Cast::Make(A->type(), Expr(scale)) *
(A(indice) + ir::Cast::Make(A->type(), Expr(bias)));
},
tensor_name);
}
// Paddle upscale float16 or bfloat16 compute to float32,
// we made CINN consistent with this behavior of Paddle
bool should_upscale_fp32 =
A->type() == common::F16() || A->type() == common::BF16();
out = Compute(
A->shape,
[=](const std::vector<Expr> &indice) {
Expr cast_scale = should_upscale_fp32
? Expr(scale)
: ir::Cast::Make(A->type(), Expr(scale));
Expr cast_bias = should_upscale_fp32
? Expr(bias)
: ir::Cast::Make(A->type(), Expr(bias));
Expr cast_A_indice =
should_upscale_fp32 ? ir::Cast::Make(common::F32(), A(indice))
: A(indice);
Expr add_result = bias_after_scale
? cast_scale * cast_A_indice + cast_bias
: cast_scale * (cast_A_indice + cast_bias);
return should_upscale_fp32 ? ir::Cast::Make(A->type(), add_result)
: add_result;
},
tensor_name);
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});
......@@ -413,6 +429,11 @@ std::vector<Type> InferDtypeForFillConstant(
return {out_type};
}
void GenerateEquationsForFillConstant(
    cinn::adt::config::OpEquationContext *ctx) {
  // fill_constant takes no input tensors, so there are no input/output
  // iterator equations to generate; this hook is intentionally a no-op.
}
std::vector<std::vector<std::string>> InferLayoutForFillConstant(
const std::vector<framework::shape_t> &input_shapes,
const std::vector<std::string> &input_layouts,
......@@ -987,6 +1008,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \
.set_attr("inferdtype", \
MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) \
.set_attr( \
"generate_equations", \
MakeOpFunction(cinn::hlir::op::GenerateEquationsForElementwise)) \
.set_attr("inferlayout", \
MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \
.set_attr<cinn::hlir::framework::OpPatternKind>( \
......@@ -1108,6 +1132,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForFillConstant))
.set_attr("inferdtype",
MakeOpFunction(cinn::hlir::op::InferDtypeForFillConstant))
.set_attr(
"generate_equations",
MakeOpFunction(cinn::hlir::op::GenerateEquationsForFillConstant))
#ifndef CINN_WITH_CUDA
.set_attr("inferlayout",
MakeOpFunction(cinn::hlir::op::InferLayoutForFillConstant))
......
......@@ -16,6 +16,7 @@
#include <functional>
#include "paddle/cinn/adt/op_equation_context.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
......@@ -78,6 +79,12 @@ std::vector<framework::shape_t> InferShapeForRelu(
return res;
}
void GenerateEquationsForRelu(cinn::adt::config::OpEquationContext *ctx) {
  // Relu is elementwise, so each output iterator equals the matching input
  // iterator; equate the two iterator tuples directly.
  //
  // Idiom fix: use !empty() instead of size() != 0 for the emptiness check.
  CHECK(!ctx->GetInTensorsRanks().empty())
      << "The inputs is empty! Please check again.";
  ctx->Equal(ctx->GetInIteratorTuple(0), ctx->GetOutIteratorTuple(0));
}
std::vector<Type> InferDtypeForRelu(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
CHECK(!inputs_type.empty())
......@@ -2328,6 +2335,8 @@ CINN_REGISTER_HELPER(nn_ops) {
"CINNStrategy", cinn::hlir::op::StrategyForRelu)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRelu))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRelu))
.set_attr("generate_equations",
MakeOpFunction(cinn::hlir::op::GenerateEquationsForRelu))
#ifndef CINN_WITH_CUDA
.set_attr("inferlayout",
MakeOpFunction(cinn::hlir::op::InferLayoutForUnary))
......
......@@ -34,45 +34,21 @@ CINNSchedule GetElementwiseScheduleFunc(
common::CINNValuePack arg_pack = args[0];
CHECK_GT(arg_pack.size(), 0U)
<< "arg_pack.size() must contains at least one element.";
// TODO(Aurelius84): For NewIrCompiler, the outputs of Compute are
// tensor_ref and not Expr.
bool is_tensor_stages = arg_pack.size() == 2U && arg_pack[0].is_tensor() &&
arg_pack[1].is_stagemap();
if (!is_tensor_stages) {
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
std::vector<common::CINNValue> res{
common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
"empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(out.as_tensor());
CHECK_EQ(arg_pack.size(), 2UL);
if (target.arch == Target::Arch::NVGPU) {
pe::CudaScheduleInjective(
stages[out.as_tensor_ref()], output_shapes.front(), target);
} else if (target.arch == Target::Arch::X86) {
pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
output_shapes.front(),
target,
vectorizable);
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
*ret = arg_pack;
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
std::vector<common::CINNValue> res{
common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
});
}
......
......@@ -18,6 +18,7 @@
#include <iostream>
#include <vector>
#include "paddle/cinn/adt/op_equation_context.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
......@@ -28,6 +29,11 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/runtime/flags.h"
PD_DECLARE_bool(cinn_enable_map_expr);
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace hlir {
......@@ -58,7 +64,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
const std::string &op_name,
BlockReduceFunc gpu_reduce_with_last_axis_func,
BlockReduceFunc gpu_reduce_without_last_axis_func,
ReduceFunc cpu_reduce_func) {
ReduceFunc common_reduce_func) {
std::vector<int> reduce_axes;
auto ndim = inputs[0]->shape.size();
if (attrs.attr_store.count("dim")) {
......@@ -127,7 +133,16 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
<< "The type of input argument " << x->name << " of " << op_name
<< " should be bool, but get " << x->type() << "! Please check.";
if (target == common::DefaultNVGPUTarget()) {
const auto &NaiveCompute = [&]() {
VLOG(3) << "Do Reduce Compute!";
auto out = common_reduce_func(x, reduce_axes, keep_dim, tensor_name);
auto stages = CreateStages({out});
std::vector<CINNValue> cinn_values{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{cinn_values};
};
if (!FLAGS_cinn_enable_map_expr && !FLAGS_cinn_new_group_scheduler &&
target == common::DefaultNVGPUTarget()) {
if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
VLOG(3) << "Do Two Step Block Reduce Compute!";
auto res = gpu_reduce_with_last_axis_func(
......@@ -154,12 +169,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
*ret = CINNValuePack{cinn_values};
}
} else {
VLOG(3) << "Do Reduce Compute!";
auto out = cpu_reduce_func(x, reduce_axes, keep_dim, tensor_name);
auto stages = CreateStages({out});
std::vector<CINNValue> cinn_values{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{cinn_values};
NaiveCompute();
}
});
......@@ -193,7 +203,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
if (target.arch == Target::Arch::NVGPU) {
if (!FLAGS_cinn_new_group_scheduler && target.arch == Target::Arch::NVGPU) {
if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
if (arg_pack.size() == 4) {
CHECK_EQ(vec_tensor.size(), 2);
......@@ -313,7 +323,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
reduce_op_, \
gpu_reduce_with_last_axis_func, \
gpu_reduce_without_last_axis_func, \
cpu_reduce_func) \
common_reduce_func) \
std::shared_ptr<OpStrategy> StrategyFor##reduce_op_( \
const framework::NodeAttr &attrs, \
const std::vector<ir::Tensor> &inputs, \
......@@ -328,7 +338,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
#op_name_, \
gpu_reduce_with_last_axis_func, \
gpu_reduce_without_last_axis_func, \
cpu_reduce_func); \
common_reduce_func); \
}
STRATEGY_FOR_REDUCE(reduce_sum,
......@@ -414,6 +424,35 @@ std::vector<shape_t> InferShapeForReduction(
return {out_shapes};
}
// Generates iterator equations for reduction ops: every non-reduced input
// axis is equated with its corresponding output axis. Reduced axes get no
// equation (their iterators are free on the input side).
void GenerateEquationsForReduction(cinn::adt::config::OpEquationContext *ctx) {
  CHECK(ctx->GetInTensorsRanks().size() != 0)
      << "The inputs is empty! Please check again.";
  // "keep_dim": whether reduced axes are kept as size-1 dims in the output.
  const bool keep_dim = ctx->Attr<bool>("keep_dim");
  // "dim": the list of axes being reduced.
  // NOTE(review): axes in `dim` appear to be assumed non-negative here
  // (they are matched against 0-based in_axis values) — confirm callers
  // normalize negative axes before this point.
  const auto &dim = ctx->Attr<std::vector<int>>("dim");
  // True iff `in_axis` is one of the reduced axes.
  const auto &IsReduceAxis = [&](const int in_axis) {
    return std::find(dim.begin(), dim.end(), in_axis) != dim.end();
  };
  // Walks input axes of input 0, tracking the aligned output axis, and
  // invokes DoEach(in_axis, out_axis) for every non-reduced axis pair.
  const auto &VisitEachAxisPair = [&](const auto &DoEach) {
    std::size_t out_axis = 0;
    for (std::size_t in_axis = 0; in_axis < ctx->GetInTensorsRanks().at(0);
         ++in_axis) {
      if (IsReduceAxis(in_axis)) {
        // A reduced axis consumes an output slot only when keep_dim is true
        // (bool promotes to 0 or 1 here), since keep_dim leaves a size-1 dim.
        out_axis += keep_dim;
      } else {
        DoEach(in_axis, out_axis);
        out_axis += 1;
      }
    }
  };
  // Equate each surviving input iterator with its aligned output iterator.
  VisitEachAxisPair([&](const int input_axis, const int output_axis) {
    ctx->Equal(ctx->GetInIteratorTuple(0)->at(input_axis),
               ctx->GetOutIteratorTuple(0)->at(output_axis));
  });
}
std::vector<Type> InferDtypeForReduction(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
CHECK(!inputs_type.empty())
......@@ -477,22 +516,24 @@ std::vector<std::vector<std::string>> InferLayoutForBnOptimize(
} // namespace cinn
CINN_REGISTER_HELPER(reduce_ops) {
#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__) \
CINN_REGISTER_OP(op__) \
.describe(#op__ " function") \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr<cinn::hlir::framework::StrategyFunction>( \
"CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \
.set_attr("infershape", \
MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \
.set_attr( \
"inferdtype", \
MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__)) \
.set_attr("inferlayout", \
MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \
.set_attr<cinn::hlir::framework::OpPatternKind>( \
"OpPattern", cinn::hlir::framework::OpPatternKind::kReduction) \
#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__) \
CINN_REGISTER_OP(op__) \
.describe(#op__ " function") \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr<cinn::hlir::framework::StrategyFunction>( \
"CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \
.set_attr("infershape", \
MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \
.set_attr( \
"inferdtype", \
MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__)) \
.set_attr("generate_equations", \
MakeOpFunction(cinn::hlir::op::GenerateEquationsForReduction)) \
.set_attr("inferlayout", \
MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \
.set_attr<cinn::hlir::framework::OpPatternKind>( \
"OpPattern", cinn::hlir::framework::OpPatternKind::kReduction) \
.set_support_level(4);
#define CINN_REGISTER_REDUCTION(op__, op_stragegy__) \
......
......@@ -39,6 +39,9 @@
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace hlir {
namespace framework {
......@@ -362,6 +365,9 @@ void TestCaseForReduce(const float init_val,
dim3 block;
grid = {c, 1, 1};
int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h;
if (FLAGS_cinn_new_group_scheduler) {
block_dim_x = 1;
}
block = {block_dim_x, 1, 1};
void* args[] = {&dev_x, &dev_z};
......@@ -531,7 +537,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce) {
std::vector<int> dim = {1};
auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce");
CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
if (!FLAGS_cinn_new_group_scheduler)
CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
}
TEST(Operator, Operator_Reduction_Case_Block_Reduce) {
......@@ -544,7 +551,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce) {
std::vector<int> dim = {1};
auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce");
CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
if (!FLAGS_cinn_new_group_scheduler)
CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
}
TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) {
......@@ -558,7 +566,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) {
auto res =
GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1");
CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
if (!FLAGS_cinn_new_group_scheduler)
CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
}
TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) {
......@@ -572,7 +581,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) {
auto res =
GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2");
CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
if (!FLAGS_cinn_new_group_scheduler)
CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
}
} // namespace framework
} // namespace hlir
......
......@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/pe/ir_schedule_pe.h"
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/hlir/pe/schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -2044,7 +2044,7 @@ CINN_REGISTER_HELPER(transform_ops) {
// pointers, the code generated by operator fusion will have out-of-bounds
// access. It should not fuse with any other injective operators, though
// scatter_add is injective. turn KNonFusible to kInjective will fail
// /Paddle/python/paddle/fluid/tests/unittests/test_index_select_op.py
// /Paddle/python/paddle/base/tests/unittests/test_index_select_op.py
.set_attr<cinn::hlir::framework::OpPatternKind>(
"OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
.set_support_level(4);
......
......@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -76,7 +76,8 @@ TEST(conv, conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -122,7 +123,8 @@ TEST(conv_relu_conv, conv_relu_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -171,7 +173,8 @@ TEST(conv_add_conv, conv_add_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -227,7 +230,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -283,7 +287,8 @@ TEST(conv_pool2d_conv, conv_pool2d_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -334,7 +339,8 @@ TEST(conv_softmax_conv, conv_softmax_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -382,7 +388,8 @@ TEST(conv_sigmoid_conv, conv_sigmoid_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -434,7 +441,8 @@ TEST(conv_mul_conv, conv_mul_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -46,7 +46,8 @@ void RunTest(const Target& target,
const std::shared_ptr<Graph>& graph,
const std::vector<std::string>& input_names) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
for (size_t i = 0; i < input_names.size(); ++i) {
scope->Var<hlir::framework::Tensor>(input_names[i]);
......
......@@ -37,7 +37,7 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -71,7 +71,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
hlir::framework::ApplyPass(graph.get(), "BuildNonFusedGroupsPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -115,7 +116,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
hlir::framework::ApplyPass(graph.get(), "BuildNonFusedGroupsPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -180,7 +182,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......
......@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -57,7 +57,8 @@ TEST(const_conv, const_conv) {
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -101,7 +102,8 @@ TEST(const_bn, const_bn) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......
......@@ -46,7 +46,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
hlir::framework::ApplyPasses(graph.get(), passes);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (auto& data : input_data) {
......
......@@ -17,7 +17,8 @@
#include "paddle/cinn/hlir/op/external_api_registry.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_custom_call_deny_ops);
PD_DECLARE_string(cinn_custom_call_deny_ops);
PD_DECLARE_bool(cinn_use_cutlass);
namespace cinn {
namespace hlir {
......@@ -72,8 +73,10 @@ class GraphAlterHelper {
}
}
node->attrs.attr_store["original_op"] = node->op()->name;
node->attrs.op = framework::Operator::Get("custom_call");
if (!FLAGS_cinn_use_cutlass || node->op()->name != "matmul") {
node->attrs.attr_store["original_op"] = node->op()->name;
node->attrs.op = framework::Operator::Get("custom_call");
}
}
}
......
......@@ -46,7 +46,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......@@ -72,7 +73,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......
......@@ -45,7 +45,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......@@ -71,7 +72,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/hlir/pass/fusion_merge_pass_util.h"
DECLARE_bool(enhance_vertical_fusion_with_recompute);
PD_DECLARE_bool(enhance_vertical_fusion_with_recompute);
namespace cinn {
namespace hlir {
......
......@@ -26,7 +26,7 @@
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"
DECLARE_bool(enhance_vertical_fusion_with_recompute);
PD_DECLARE_bool(enhance_vertical_fusion_with_recompute);
namespace cinn {
namespace hlir {
......
......@@ -52,11 +52,11 @@ class FusionPassRegistrar final : public Registrar {
#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_pass__##pass_name, \
__reg_cinn_fusion_pass__##pass_name, \
"CINN_REGISTER_FUSION_PASS must be called in global namespace"); \
static ::cinn::hlir::pass::FusionPassRegistrar<pass_class> \
__pass_registrar_##pass_name##__(#pass_name); \
int TouchFusionPassRegistrar_##pass_name() { \
__pass_registrar_##pass_name##__.Touch(); \
__cinn_fusion_pass_registrar_##pass_name##__(#pass_name); \
int TouchCinnFusionPassRegistrar_##pass_name() { \
__cinn_fusion_pass_registrar_##pass_name##__.Touch(); \
return 0; \
}
......@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -80,7 +80,8 @@ TEST(complex2, complex2) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -135,7 +136,8 @@ TEST(complex1, complex1) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -172,7 +174,8 @@ TEST(fuse_add_relu, fuse_add_relu) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -210,7 +213,8 @@ TEST(fuse_add, fuse_add) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -268,7 +272,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -319,7 +324,8 @@ TEST(fuse_conv_add, fuse_conv_add) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -377,7 +383,8 @@ TEST(conv_add_mul, conv_add_mul) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -426,7 +433,8 @@ TEST(fuse_conv_add1, fuse_conv_add1) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -465,7 +473,8 @@ TEST(transpose_reshape_concat, transpose_reshape_concat) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -517,7 +526,8 @@ TEST(conv_bn, conv_bn) {
hlir::framework::ApplyPass(graph.get(), "OpFusion");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -30,7 +30,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
hlir::framework::ApplyPasses(graph.get(), passes);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (auto& data : input_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment