Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
add_subdirectory(pir)
core_gather_headers() core_gather_headers()
gather_srcs( gather_srcs(
...@@ -12,22 +13,21 @@ gather_srcs( ...@@ -12,22 +13,21 @@ gather_srcs(
program.cc program.cc
parallel_compiler.cc parallel_compiler.cc
graph_compiler.cc graph_compiler.cc
graph_compiler_util.cc
graph.cc graph.cc
node.cc node.cc
pass.cc pass.cc
op_strategy.cc op_strategy.cc
op_lowering.cc
op_lowering_util.cc op_lowering_util.cc
op_lowering_impl.cc
accuracy_checker.cc accuracy_checker.cc
visualize_helper.cc) visualize_helper.cc
compile_error.cc)
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could # TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could
# not found under CINN_ONLY mode # not found under CINN_ONLY mode
if(NOT CINN_ONLY) if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp)
pd_dialect)
cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi
cinn_dialect)
endif() endif()
if(WITH_CUDA) if(WITH_CUDA)
...@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore) ...@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc cinn_cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
DEPS cinncore) DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore) cinn_cc_test(test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
#cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc DEPS cinncore) DEPS cinncore)
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif #endif
DECLARE_int64(cinn_self_check_accuracy_num); PD_DECLARE_int64(cinn_self_check_accuracy_num);
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/framework/instruction.h" #include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h" #include "paddle/cinn/hlir/framework/op_strategy.h"
DECLARE_string(cinn_self_check_accuracy); PD_DECLARE_string(cinn_self_check_accuracy);
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/utils/enum_string.h"
namespace cinn {
namespace hlir {
namespace framework {
std::string CompileErrorHandler::GeneralErrorMessage() const {
std::ostringstream os;
os << "[CompileError] An error occurred during compilation with the error "
"code: "
<< utils::Enum2String(status_) << std::endl;
os << "(at " << file_ << " : " << line_ << ")" << std::endl;
os << indent_str_ << "[Error info] " << this->err_msg_ << std::endl;
return os.str();
}
// Builds the kDetailed-level report: the general message followed by the
// extra detail_info_ context supplied at construction time.
std::string CompileErrorHandler::DetailedErrorMessage() const {
  std::ostringstream os;
  os << GeneralErrorMessage();
  // '\n' instead of std::endl — no flush needed on an ostringstream
  // (clang-tidy performance-avoid-endl); output bytes are unchanged.
  os << indent_str_ << "[Detail info] " << detail_info_ << '\n';
  return os.str();
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace cinn {
namespace hlir {
namespace framework {
/**
* This handler is used to deal with the errors during the compilation process
*/
class CompileErrorHandler : public utils::ErrorHandler {
 public:
  /**
   * \brief constructor
   * \param status the compilation status/stage that triggered this error
   * \param err_msg the error message
   * \param detail_info verbose context, shown only by DetailedErrorMessage()
   * \param file source file name recorded at the raise site
   * \param line source line number recorded at the raise site
   */
  explicit CompileErrorHandler(const CompilationStatus& status,
                               const std::string& err_msg,
                               const std::string& detail_info,
                               const char* file,
                               int line)
      : status_(status),
        err_msg_(err_msg),
        detail_info_(detail_info),
        file_(file),
        line_(line) {}

  /**
   * \brief Returns a short error message corresponding to the kGeneral error
   * level.
   */
  std::string GeneralErrorMessage() const;

  /**
   * \brief Returns a detailed error message corresponding to the kDetailed
   * error level.
   */
  std::string DetailedErrorMessage() const;

  // The compilation status carried by this error.
  CompilationStatus Status() const { return status_; }

 private:
  CompilationStatus status_;    // which status/stage failed
  std::string err_msg_;         // short error description
  std::string detail_info_;     // extra context for DetailedErrorMessage()
  // NOTE(review): file_ is a non-owning pointer, presumably from __FILE__
  // (a string literal with static storage) — confirm at the raise sites.
  const char* file_;
  int line_;                    // line number at the raise site
};
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include <string>
#include <unordered_map>
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
namespace cinn {
namespace hlir {
namespace framework {
// Converts a compiled CINN program into an ::ir::Program in the runtime
// dialect: every run-instruction becomes one `JitKernelOp` that carries a
// raw pointer to the instruction (via PointerAttribute).
// NOTE(review): the created ops hold non-owning pointers into `program`'s
// instructions; the input program must outlive the returned ::ir::Program.
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program) {
  ::ir::IrContext* ctx = ::ir::IrContext::Instance();
  ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
  auto ir_program = std::make_unique<::ir::Program>(ctx);

  // All JitKernelOps share the same OpInfo; resolve it once before the loop.
  std::string jit_op_name = dialect::JitKernelOp::name();
  ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);

  auto& instrs = program.GetRunInstructions();
  for (auto& instr : instrs) {
    std::unordered_map<std::string, ::ir::Attribute> op_attrs{
        {dialect::JitKernelOp::kAttrName,
         ::ir::PointerAttribute::get(ctx, instr.get())},
    };

    ::ir::Operation* cinn_op =
        ::ir::Operation::Create({}, op_attrs, {}, op_info);
    ir_program->block()->push_back(cinn_op);
  }

  // Return the local directly: `return std::move(local)` is a pessimizing
  // move (C++ Core Guidelines F.48); the unique_ptr is moved implicitly.
  return ir_program;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
// Forward declaration of the new-IR program type (paddle/ir); the full
// definition is only needed by the implementation file.
namespace ir {
class Program;
}  // namespace ir

namespace cinn {
namespace hlir {
namespace framework {
// Forward declaration of CINN's compiled runtime program.
class Program;

// Wraps each run-instruction of the compiled CINN `program` into a
// runtime-dialect op inside a freshly created ::ir::Program and returns it.
// NOTE(review): the result appears to reference `program`'s instructions by
// raw pointer, so `program` must outlive the returned ::ir::Program —
// confirm against the implementation.
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program);
}  // namespace framework
}  // namespace hlir
}  // namespace cinn
...@@ -18,10 +18,14 @@ ...@@ -18,10 +18,14 @@
#include <sstream> #include <sstream>
#include "paddle/cinn/hlir/framework/visualize_helper.h" #include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#endif
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_fusion_groups_graphviz_dir); PD_DECLARE_string(cinn_fusion_groups_graphviz_dir);
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
...@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph( ...@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph(
} }
// Dump debug info for each group // Dump debug info for each group
LOG(INFO) << "Dump graph debug info to: " VLOG(4) << "Dump graph debug info to: "
<< FLAGS_cinn_fusion_groups_graphviz_dir; << FLAGS_cinn_fusion_groups_graphviz_dir;
const auto& groups = RemoveAccCheckGroups(origin_groups); const auto& groups = RemoveAccCheckGroups(origin_groups);
const auto& group_dots = VisualizeGroups(groups, fetch_var_ids); const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
for (int idx = 0; idx < groups.size(); ++idx) { for (int idx = 0; idx < groups.size(); ++idx) {
// Create fusion_group_x folder // Create fusion_group_x folder
int device_id = 0;
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);
#endif
auto group_path = auto group_path =
utils::StringFormat("%s/fusion_group_%d", utils::StringFormat("%s/device_%d/fusion_group_%d",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(), FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
device_id,
idx); idx);
if (!MakeDirectory(group_path, if (!MakeDirectory(group_path,
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
...@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups( ...@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups(
return dot_vec; return dot_vec;
} }
std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() { std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() const {
std::unordered_set<NodeData*> group_inputs; std::unordered_set<NodeData*> group_inputs;
// count all node's input data // count all node's input data
...@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() { ...@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
return group_inputs; return group_inputs;
} }
std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() { std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() const {
std::unordered_set<NodeData*> group_outputs; std::unordered_set<NodeData*> group_outputs;
for (auto node : this->output_nodes) { for (auto node : this->output_nodes) {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/node.h"
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
namespace framework { namespace framework {
...@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph { ...@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph {
std::vector<std::vector<Node*>> groups; std::vector<std::vector<Node*>> groups;
struct Group { struct Group {
Group() = default; Group() = default;
Group(const Group&) = delete;
Group(Group&&) = delete;
explicit Group(const Graph* graph) : graph_(graph) {} explicit Group(const Graph* graph) : graph_(graph) {}
...@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph { ...@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph {
} }
}; };
std::vector<Node*> CollectNodes() { std::vector<Node*> CollectNodes() const {
if (fused_sub_groups.size()) { if (fused_sub_groups.size()) {
std::vector<Node*> tmp_nodes; std::vector<Node*> tmp_nodes;
for (auto& group : fused_sub_groups) { for (auto& group : fused_sub_groups) {
...@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph { ...@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph {
return node_set; return node_set;
} }
std::unordered_set<NodeData*> GetInputNodeDatas(); std::unordered_set<NodeData*> GetInputNodeDatas() const;
std::unordered_set<NodeData*> GetOutputNodeDatas(); std::unordered_set<NodeData*> GetOutputNodeDatas() const;
std::string GetFuncName() { return "fn_" + group_id + unique_id; } std::string GetFuncName() { return "fn_" + group_id + unique_id; }
......
...@@ -29,8 +29,11 @@ ...@@ -29,8 +29,11 @@
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/poly/stage.h" #include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h" #include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
namespace framework { namespace framework {
...@@ -40,90 +43,124 @@ using cinn::common::float16; ...@@ -40,90 +43,124 @@ using cinn::common::float16;
std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) { std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph); utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
GraphCompiler::CompileOptions options; compilation_context_.ApplySourceCode(code);
options.attached_code = code; compilation_context_.with_instantiate_variables = true;
options.with_instantiate_variables = true;
auto&& result = Build(options);
return std::move(result.runtime_program);
}
void GraphCompiler::CompileOptions::Apply( auto&& result = Build(&compilation_context_);
const auto_schedule::TuningResult& tuning_result) { return result.RuntimeProgram();
// assign options with TuningResult directly
groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
lowered_funcs.assign(tuning_result.function_groups.begin(),
tuning_result.function_groups.end());
} }
GraphCompiler::CompilationResult GraphCompiler::Build( CompilationResult GraphCompiler::Build(CompilationContext* context) {
const GraphCompiler::CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids,
void* stream) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir // write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_ context->graph->VisualizeGroupedGraph(context->fetch_var_ids);
: fetch_var_ids);
if (options.with_instantiate_variables) { if (context->with_instantiate_variables) {
InstantiateVariables(); InstantiateVariables(context);
} }
VLOG(2) << "Compile With Parallel Compiler!"; VLOG(2) << "Compile With Parallel Compiler!";
utils::RecordEvent("GraphCompiler CompileResult", utils::RecordEvent("GraphCompiler CompileResult",
utils::EventType::kOrdinary); utils::EventType::kOrdinary);
ParallelCompiler::CompileOptions option;
option.lowered_funcs = options.lowered_funcs;
parallel_compiler_ = parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_); CompilationResult result = (*parallel_compiler_.get())();
auto result = (*parallel_compiler_.get())();
// Dump compilation result if (context->stage != CompilationStage::DEFAULT || !result.IsSuccess()) {
backends::CompilationInfoDumper dumper(result); return result;
}
if (options.remove_unused_variables) { if (context->remove_unused_variables) {
RemoveInvalidVariables(result.instructions); RemoveInvalidVariables(context, result.RuntimeInstructions());
} }
if (options.with_buffer_handle_instruction_inserted) { if (context->with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable"; VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&result.instructions); InsertBufferHandlers(context, &result.instructions_);
} }
VLOG(2) << "Compile With Parallel Compiler Done!"; VLOG(2) << "Compile With Parallel Compiler Done!";
GraphCompiler::CompilationResult compilation_result; result.SetRuntimeProgram(std::make_unique<Program>(
compilation_result.runtime_program.reset( context->scope, std::move(result.instructions_)));
new Program(scope_, std::move(result.instructions))); return result;
return compilation_result; }
CompilationResult GraphCompiler::Lowering() {
return Lowering(&compilation_context_);
} }
void GraphCompiler::InstantiateVariables() { CompilationResult GraphCompiler::Lowering(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just lowering!";
context->stage = CompilationStage::LOWERING;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::CodegenAndJit() {
return CodegenAndJit(&compilation_context_);
}
CompilationResult GraphCompiler::CodegenAndJit(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just codegen and jit!";
context->stage = CompilationStage::CODEGEN_AND_JIT;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::BuildInstruction() {
return BuildInstruction(&compilation_context_);
}
CompilationResult GraphCompiler::BuildInstruction(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just build instruction!";
context->stage = CompilationStage::BUILD_INSTRUCTION;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
void GraphCompiler::InstantiateVariables(CompilationContext* context) {
VLOG(3) << "Instantiate all variables on compile-time"; VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary); utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary);
// All variables reside in scope_, so traverse it to instantiate each one // All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) { for (auto& name : context->scope->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()})); auto* var =
context->scope->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var); auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) { if (context->reuse_vars_map.count(name)) {
auto src_var_name = reuse_vars_map_.at(name); auto src_var_name = context->reuse_vars_map.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name); auto* src_var = context->scope->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var); auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer()); tensor->set_buffer(src_tensor->get_buffer());
} else { } else {
tensor->mutable_data(target_, tensor->type()); tensor->mutable_data(context->target, tensor->type());
} }
} }
} }
void GraphCompiler::RemoveInvalidVariables( void GraphCompiler::RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions) { const std::vector<std::unique_ptr<Instruction>>& instructions) {
// mark all variables are invalid initially // mark all variables are invalid initially
utils::RecordEvent("GraphCompiler RemoveInvalidVariables", utils::RecordEvent("GraphCompiler RemoveInvalidVariables",
utils::EventType::kOrdinary); utils::EventType::kOrdinary);
std::unordered_set<std::string> invalid_variables; std::unordered_set<std::string> invalid_variables;
auto var_names = scope_->var_names(); auto var_names = context->scope->var_names();
invalid_variables.reserve(var_names.size()); invalid_variables.reserve(var_names.size());
std::transform( std::transform(
var_names.begin(), var_names.begin(),
...@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables( ...@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
<< " invalid variables to be removed from scope"; << " invalid variables to be removed from scope";
std::for_each(invalid_variables.begin(), std::for_each(invalid_variables.begin(),
invalid_variables.end(), invalid_variables.end(),
[this](const std::string& var_name) { [context](const std::string& var_name) {
scope_->EraseVar(var_name); context->scope->EraseVar(var_name);
VLOG(3) << "Variable(" << var_name << ") is erased"; VLOG(3) << "Variable(" << var_name << ") is erased";
}); });
} }
...@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime( ...@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
} }
void GraphCompiler::InsertBufferHandlers( void GraphCompiler::InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions) { std::vector<std::unique_ptr<Instruction>>* instructions) {
utils::RecordEvent("GraphCompiler InsertBufferHandlers", utils::RecordEvent("GraphCompiler InsertBufferHandlers",
utils::EventType::kOrdinary); utils::EventType::kOrdinary);
...@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers( ...@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "malloc_buffer_instruction_" + std::to_string(step); auto function_name = "malloc_buffer_instruction_" + std::to_string(step);
auto malloc_instr = auto malloc_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(), std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(), context->scope.get(),
malloc_var_names, malloc_var_names,
std::vector<std::string>({}), std::vector<std::string>({}),
function_name); function_name);
...@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers( ...@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "free_buffer_instruction_" + std::to_string(step); auto function_name = "free_buffer_instruction_" + std::to_string(step);
auto free_instr = auto free_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(), std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(), context->scope.get(),
std::vector<std::string>({}), std::vector<std::string>({}),
free_var_names, free_var_names,
function_name); function_name);
...@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl( ...@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
poly::StageMap stages = C.back(); poly::StageMap stages = C.back();
std::string func_name_prefix = "fn_"; std::string func_name_prefix = "fn_";
auto funcs = lang::LowerVec(func_name_prefix + node_id,
stages, ast_gen_ius::TensorGroup tensor_group =
all_arg_tensors, ast_gen_ius::ConvertStageMapToTensorGroup(stages);
{}, auto funcs = lang::LowerToAstVec(
{}, func_name_prefix + node_id, all_arg_tensors, &tensor_group, target);
nullptr,
target, VLOG(4) << "Lower op: " << node_id << ", get " << funcs.size()
true); << " LoweredFunc:\n";
for (auto fun : funcs) {
VLOG(4) << fun;
}
std::vector<common::CINNValue> schedule_inputs; std::vector<common::CINNValue> schedule_inputs;
for (int i = 0; i < C.size() - 1; ++i) { for (int i = 0; i < C.size() - 1; ++i) {
...@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl( ...@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body)); optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
#endif #endif
auto temp_buffers = lang::GetTempBuffers( auto temp_buffers = lang::GetTempBuffers(
all_arg_tensors, stages, funcs_after_schedule[i]->body); all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);
funcs_after_schedule[i]->temp_bufs = temp_buffers; funcs_after_schedule[i]->temp_bufs = temp_buffers;
funcs_after_schedule[i] = funcs_after_schedule[i] =
ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name, ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name,
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/macros.h" #include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/instruction.h" #include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h" #include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h" #include "paddle/cinn/hlir/framework/parallel_compiler.h"
...@@ -46,48 +47,41 @@ namespace framework { ...@@ -46,48 +47,41 @@ namespace framework {
*/ */
class GraphCompiler final { class GraphCompiler final {
public: public:
GraphCompiler(Target target, GraphCompiler(CompilationContext context) : compilation_context_(context) {}
const std::shared_ptr<Scope>& scope,
const std::shared_ptr<Graph>& graph)
: target_(std::move(target)), scope_(scope), graph_(graph) {}
struct CompilationResult {
std::unique_ptr<Program> runtime_program;
};
struct CompileOptions {
std::string attached_code = "";
bool with_instantiate_variables = false;
bool with_buffer_handle_instruction_inserted = false;
bool remove_unused_variables = true;
// nodes group, it may come from the result of op fusion or graph tuning.
// nodes in a group will be built into an Instruction
std::vector<std::shared_ptr<Graph::Group>> groups;
// corresponding LoweredFuncs of above grouped nodes,
// if it is empty then graph_compiler will generate for them
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// apply results of auto-tune to compile
void Apply(const auto_schedule::TuningResult& tuning_result);
};
// Compile with a packing option and result, to be extended easily. // Compile with a packing option and result, to be extended easily.
CompilationResult Build(const CompileOptions& options, CompilationResult Build(CompilationContext* context);
std::unordered_set<std::string>&& fetch_var_ids = {},
void* stream = nullptr);
std::unique_ptr<Program> Build(const std::string& code = ""); std::unique_ptr<Program> Build(const std::string& code = "");
const std::shared_ptr<Scope>& GetScope() const { return scope_; } CompilationResult Lowering();
CompilationResult Lowering(CompilationContext* context);
CompilationResult CodegenAndJit();
CompilationResult CodegenAndJit(CompilationContext* context);
CompilationResult BuildInstruction();
CompilationResult BuildInstruction(CompilationContext* context);
const std::shared_ptr<Scope>& GetScope() const {
return compilation_context_.scope;
}
CompilationContext& GetCompilationContext() { return compilation_context_; }
void SetCompilationContext(const CompilationContext& context) {
compilation_context_ = context;
}
private: private:
// instantiate all variables on compile time // instantiate all variables on compile time
void InstantiateVariables(); void InstantiateVariables(CompilationContext* context);
// some variables are eliminated by optimized passes(such as OpFusion), // some variables are eliminated by optimized passes(such as OpFusion),
// we can filter out them according to arguments of the built instructions, // we can filter out them according to arguments of the built instructions,
// and erase them from the scope to avoid unnecessary buffer allocation // and erase them from the scope to avoid unnecessary buffer allocation
void RemoveInvalidVariables( void RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions); const std::vector<std::unique_ptr<Instruction>>& instructions);
// find the first and last instruction where a variable used, and mark the // find the first and last instruction where a variable used, and mark the
...@@ -102,21 +96,14 @@ class GraphCompiler final { ...@@ -102,21 +96,14 @@ class GraphCompiler final {
// firstly used in the next instruction, and insert a buffer free instruction // firstly used in the next instruction, and insert a buffer free instruction
// applying on variables after no instruction will use them anymore // applying on variables after no instruction will use them anymore
void InsertBufferHandlers( void InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions); std::vector<std::unique_ptr<Instruction>>* instructions);
private: private:
// parallel compiler // parallel compiler
std::shared_ptr<ParallelCompiler> parallel_compiler_; std::shared_ptr<ParallelCompiler> parallel_compiler_;
Target target_; CompilationContext compilation_context_;
std::shared_ptr<Graph> graph_;
std::shared_ptr<Scope> scope_;
// fetch var ids in cinn and the corresponding var nodes will not be fused so
// as to get the result
std::unordered_set<std::string> fetch_var_ids_;
// map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map_;
CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler); CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler);
}; };
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h" #include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h" #include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
...@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) { ...@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
ASSERT_EQ(scope->var_names().size(), 6); ASSERT_EQ(scope->var_names().size(), 6);
EXPECT_NE(scope->FindVar(c->id), nullptr); EXPECT_NE(scope->FindVar(c->id), nullptr);
GraphCompiler gc(target, scope, graph); CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto runtime_program = gc.Build(); auto runtime_program = gc.Build();
ASSERT_EQ(scope->var_names().size(), 3); ASSERT_EQ(scope->var_names().size(), 3);
EXPECT_EQ(scope->FindVar(c->id), nullptr); EXPECT_EQ(scope->FindVar(c->id), nullptr);
...@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) { ...@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
auto graph = Optimize(&program, {}, target); auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
GraphCompiler gc_disable(target, scope, graph); CompilationContext context_disable(graph, scope, target);
GraphCompiler::CompileOptions options; GraphCompiler gc_disable(context_disable);
// disable with_buffer_handle_instruction_inserted: only 1 instruction // disable with_buffer_handle_instruction_inserted: only 1 instruction
auto runtime_program_disable = gc_disable.Build(options).runtime_program; auto runtime_program_disable =
gc_disable.Build(&context_disable).RuntimeProgram();
ASSERT_EQ(runtime_program_disable->size(), 1); ASSERT_EQ(runtime_program_disable->size(), 1);
const auto& computation_instr_disable = const auto& computation_instr_disable =
runtime_program_disable->GetRunInstructions().front(); runtime_program_disable->GetRunInstructions().front();
...@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) { ...@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
// enable with_buffer_handle_instruction_inserted: 3 instructions, 1st -> // enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
// malloc instruction(a, b, d), 2nd -> the real computation // malloc instruction(a, b, d), 2nd -> the real computation
// instruction(add + relu) and 3rd -> free instruction // instruction(add + relu) and 3rd -> free instruction
GraphCompiler gc_enable(target, scope, graph); CompilationContext context_enable(graph, scope, target);
options.with_buffer_handle_instruction_inserted = true; context_enable.with_buffer_handle_instruction_inserted = true;
auto runtime_program_enable = gc_enable.Build(options).runtime_program; GraphCompiler gc_enable(context_enable);
auto runtime_program_enable =
gc_enable.Build(&context_enable).RuntimeProgram();
const auto& instructions = runtime_program_enable->GetRunInstructions(); const auto& instructions = runtime_program_enable->GetRunInstructions();
ASSERT_EQ(instructions.size(), 3); ASSERT_EQ(instructions.size(), 3);
...@@ -193,7 +198,8 @@ void RunCublas( ...@@ -193,7 +198,8 @@ void RunCublas(
hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
GraphCompiler gc(target, scope, graph); CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto exe_program = gc.Build(); auto exe_program = gc.Build();
auto data_a = scope->GetTensor("A"); auto data_a = scope->GetTensor("A");
...@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) { ...@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
RunCublas(64, 128, 128, true, true); RunCublas(64, 128, 128, true, true);
} }
TEST(GraphCompilerTest, TestLowering) {
frontend::NetBuilder builder("test_lowering_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.Lowering();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestCodegenAndJit) {
frontend::NetBuilder builder("test_codegen_and_jit_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.CodegenAndJit();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestBuildInstruction) {
frontend::NetBuilder builder("test_build_instruction_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.BuildInstruction();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
#endif #endif
} // namespace framework } // namespace framework
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace cinn {
namespace hlir {
namespace framework {
// Overwrites the context's groups and lowered functions with the results of
// auto-tuning; compilation then resumes from CompilationStage::CODEGEN_AND_JIT.
void CompilationContext::ApplyTuningResult(
    const auto_schedule::TuningResult& tuning_result) {
  groups = decltype(groups)(tuning_result.subgraphs.begin(),
                            tuning_result.subgraphs.end());
  lowered_funcs = decltype(lowered_funcs)(tuning_result.function_groups.begin(),
                                          tuning_result.function_groups.end());
}
// Attaches externally provided source code; a non-empty string replaces the
// generated device-module code after SplitCudaAndHostModule.
void CompilationContext::ApplySourceCode(const std::string& code) {
  attached_source_code.assign(code);
}
// Sizes every per-group slot for `group_size` groups and starts each group in
// the SUCCESS state; later stages downgrade entries on failure.
void CompilationResult::InitCompilationResult(int group_size) {
  size_ = group_size;
  status_.resize(group_size, CompilationStatus::SUCCESS);
  messages_.clear();
  messages_.reserve(group_size);
  for (int i = 0; i < group_size; ++i) {
    messages_.emplace_back("Group Idx: " + std::to_string(i) +
                           ", Compile Success.\n");
  }
  // Per-stage artifacts are absent (nullopt / null) until the corresponding
  // stage produces them.
  lowered_funcs_.resize(group_size, std::nullopt);
  source_codes_.resize(group_size, std::nullopt);
  source_ptxs_.resize(group_size, std::nullopt);
  instructions_.resize(group_size);
}
// Records the compilation status of group `idx`. Out-of-range (including
// negative, which wraps to a huge unsigned value) indices are ignored.
void CompilationResult::SetStatus(int idx, const CompilationStatus& status) {
  if (static_cast<size_t>(idx) >= status_.size()) {
    return;
  }
  status_[idx] = status;
}
// Records the diagnostic message of group `idx`; out-of-range writes are
// dropped silently.
void CompilationResult::SetMessage(int idx, const std::string& message) {
  if (static_cast<size_t>(idx) >= messages_.size()) {
    return;
  }
  messages_[idx] = message;
}
// Records the lowered functions of group `idx`; out-of-range writes are
// dropped silently.
void CompilationResult::SetLoweredFuncs(
    int idx, const std::vector<ir::LoweredFunc>& funcs) {
  if (static_cast<size_t>(idx) >= lowered_funcs_.size()) {
    return;
  }
  lowered_funcs_[idx] = funcs;
}
// Records the generated source code of group `idx`; out-of-range writes are
// dropped silently.
void CompilationResult::SetSourceCode(int idx, const std::string& source_code) {
  if (static_cast<size_t>(idx) >= source_codes_.size()) {
    return;
  }
  source_codes_[idx] = source_code;
}
// Records the compiled PTX of group `idx`; out-of-range writes are dropped
// silently.
void CompilationResult::SetSourcePtx(int idx, const std::string& source_ptx) {
  if (static_cast<size_t>(idx) >= source_ptxs_.size()) {
    return;
  }
  source_ptxs_[idx] = source_ptx;
}
// Takes ownership of the built instruction for group `idx`; if `idx` is out
// of range the instruction is destroyed when the parameter goes out of scope.
void CompilationResult::SetInstruction(
    int idx, std::unique_ptr<Instruction> instruction) {
  if (static_cast<size_t>(idx) >= instructions_.size()) {
    return;
  }
  instructions_[idx] = std::move(instruction);
}
// Stores the fully built runtime program, taking ownership of it.
// The program is later moved out to the caller via RuntimeProgram().
void CompilationResult::SetRuntimeProgram(
    std::unique_ptr<Program> runtime_program) {
  runtime_program_ = std::move(runtime_program);
}
bool CompilationResult::IsSuccess() const {
for (const CompilationStatus& s : status_) {
if (s != CompilationStatus::SUCCESS) {
return false;
}
}
return true;
}
// Returns the overall status. CompilationStatus values are ordered with the
// most severe first (UNKNOWN_FAIL = 0 ... SUCCESS = 5), so the overall status
// is the minimum across all groups.
CompilationStatus CompilationResult::Status() const {
  CompilationStatus overall = CompilationStatus::SUCCESS;
  for (size_t i = 0; i < status_.size(); ++i) {
    if (status_[i] < overall) {
      overall = status_[i];
    }
  }
  return overall;
}
// Returns the status of group `idx`; out-of-range queries (including
// negative indices, which wrap to huge unsigned values) degrade to
// UNKNOWN_FAIL instead of throwing.
CompilationStatus CompilationResult::Status(int idx) const {
  return static_cast<size_t>(idx) < status_.size()
             ? status_[idx]
             : CompilationStatus::UNKNOWN_FAIL;
}
// Concatenates the per-group messages in group order into one report.
std::string CompilationResult::Message() const {
  std::string report;
  for (const std::string& msg : messages_) {
    report += msg;
  }
  return report;
}
// Returns the compilation message of group `idx`.
// Throws (via CINN_THROW) when `idx` is out of range.
std::string CompilationResult::Message(int idx) const {
  if (idx >= messages_.size()) {
    std::stringstream ss;
    // BUGFIX: report the size of the container actually being indexed
    // (messages_), not lowered_funcs_.
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << messages_.size() << ").";
    CINN_THROW(ss.str());
  }
  return messages_[idx];
}
// Collects the lowered functions of every group, in group order.
// Throws at the first group that has not been lowered (e.g. because an
// earlier stage failed), embedding the accumulated messages.
std::vector<std::vector<ir::LoweredFunc>> CompilationResult::LoweredFuncs()
    const {
  std::vector<std::vector<ir::LoweredFunc>> funcs;
  funcs.reserve(lowered_funcs_.size());
  for (size_t i = 0; i < lowered_funcs_.size(); ++i) {
    if (!lowered_funcs_[i].has_value()) {
      std::stringstream ss;
      ss << "LoweredFuncs of group[" << i << "] is not generated.\n"
         << "Some errors may have occurred during or before the lower "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
    funcs.push_back(lowered_funcs_[i].value());
  }
  return funcs;
}
// Returns the lowered functions of group `idx`.
// Throws when the index is out of range or the group has not been lowered.
std::vector<ir::LoweredFunc> CompilationResult::LoweredFuncs(int idx) const {
  if (static_cast<size_t>(idx) >= lowered_funcs_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  const auto& funcs = lowered_funcs_[idx];
  if (!funcs.has_value()) {
    std::stringstream ss;
    ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the lower process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return funcs.value();
}
// Collects the generated source code of every group, in group order.
// Throws at the first group whose code has not been generated.
std::vector<std::string> CompilationResult::SourceCodes() const {
  std::vector<std::string> codes;
  codes.reserve(source_codes_.size());
  for (size_t i = 0; i < source_codes_.size(); ++i) {
    if (!source_codes_[i].has_value()) {
      std::stringstream ss;
      ss << "Source Code of group[" << i << "] is not generated.\n"
         << "Some errors may have occurred during or before the codegen "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
    codes.push_back(source_codes_[i].value());
  }
  return codes;
}
// Returns the generated source code of group `idx`.
// Throws when the index is out of range or codegen has not produced code
// for this group.
std::string CompilationResult::SourceCode(int idx) const {
  if (idx >= source_codes_.size()) {
    std::stringstream ss;
    // BUGFIX: report the size of the container actually being indexed
    // (source_codes_), not lowered_funcs_.
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << source_codes_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!source_codes_[idx].has_value()) {
    std::stringstream ss;
    ss << "Source Code of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the codegen "
          "process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return source_codes_[idx].value();
}
// Collects the compiled PTX of every group, in group order.
// Throws at the first group whose PTX has not been generated.
std::vector<std::string> CompilationResult::SourcePtxs() const {
  std::vector<std::string> ptxs;
  ptxs.reserve(source_ptxs_.size());
  for (size_t i = 0; i < source_ptxs_.size(); ++i) {
    if (!source_ptxs_[i].has_value()) {
      std::stringstream ss;
      ss << "Source PTX of group[" << i << "] is not generated.\n"
         << "Some errors may have occurred during or before the nvrtc compile "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
    ptxs.push_back(source_ptxs_[i].value());
  }
  return ptxs;
}
// Returns the compiled PTX of group `idx`.
// Throws when the index is out of range or the nvrtc compile has not
// produced PTX for this group.
std::string CompilationResult::SourcePtx(int idx) const {
  if (idx >= source_ptxs_.size()) {
    std::stringstream ss;
    // BUGFIX: report the size of the container actually being indexed
    // (source_ptxs_), not lowered_funcs_.
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << source_ptxs_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!source_ptxs_[idx].has_value()) {
    std::stringstream ss;
    ss << "Source PTX of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the nvrtc compile "
          "process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return source_ptxs_[idx].value();
}
// Returns the run instructions, preferring those owned by the runtime
// program when it has been built; otherwise returns the per-group
// instructions, throwing if any group's instruction is missing.
const std::vector<std::unique_ptr<Instruction>>&
CompilationResult::RuntimeInstructions() const {
  if (runtime_program_ != nullptr) {
    return runtime_program_->GetRunInstructions();
  }
  for (size_t i = 0; i < instructions_.size(); ++i) {
    if (instructions_[i] != nullptr) {
      continue;
    }
    std::stringstream ss;
    ss << "Instruction of group[" << i << "] is not generated.\n"
       << "Some errors may have occurred during or before the build "
          "instruction process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return instructions_;
}
const std::unique_ptr<Instruction>& CompilationResult::RuntimeInstruction(
int idx) const {
const std::vector<std::unique_ptr<Instruction>>& insts =
runtime_program_ ? runtime_program_->GetRunInstructions() : instructions_;
if (idx >= insts.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group(" << insts.size()
<< ").";
CINN_THROW(ss.str());
}
return insts[idx];
}
// Moves the runtime program out to the caller.
// NOTE: ownership is transferred, so runtime_program_ becomes null and a
// second call takes the "not generated" error path below.
std::unique_ptr<Program> CompilationResult::RuntimeProgram() {
  if (runtime_program_ == nullptr) {
    std::stringstream ss;
    ss << "Runtime program is not generated.\n"
       << "Some errors may have occurred during the compilation process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return std::move(runtime_program_);
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/ir/lowered_func.h"
namespace cinn {
namespace hlir {
namespace framework {
// An enum class used to control the compilation stage.
// Stages form a pipeline (lowering -> codegen & jit -> build instruction ->
// build runtime program); each value stops compilation after that step.
enum class CompilationStage {
  // Fully compiled by default, the following compilation result can be
  // obtained: lowered_function, source_code, source_ptx, instruction and
  // runtime_program.
  DEFAULT = 0,
  // Just do lowering, we can only get lowered_function from compilation result.
  LOWERING = 1,
  // Stop after codegen and jit, we can get: lowered_function, source_code and
  // source_ptx from compilation result.
  CODEGEN_AND_JIT = 2,
  // Stop after build instruction, we can get: lowered_function, source_code,
  // source_ptx and runtime_program from compilation result.
  // NOTE(review): "runtime_program" above looks like it should read
  // "instruction" -- CompilationResult::RuntimeInstructions() falls back to
  // the per-group instructions when no runtime program exists; confirm
  // against GraphCompiler::Build.
  BUILD_INSTRUCTION = 3,
};
// An enum class used to represent the compilation status.
// Values are ordered by severity (lower value = more severe failure);
// CompilationResult::Status() reports the minimum value across all groups.
enum class CompilationStatus {
  // An unknown error occurred during compilation.
  UNKNOWN_FAIL = 0,
  // An error occurred during lowering.
  LOWERING_FAIL = 1,
  // An error occurred during codegen and jit.
  CODEGEN_JIT_FAIL = 2,
  // An error occurred during build instruction.
  // (Enumerator name has a typo -- "INSTUCTION" -- kept as-is for API
  // compatibility with existing callers.)
  INSTUCTION_FAIL = 3,
  // An error occurred during build runtime program.
  PROGRAM_FAIL = 4,
  // Compile successfully.
  SUCCESS = 5,
};
// Aggregates every input and option GraphCompiler needs for one compilation:
// the graph, variable scope, target, per-stage options, and optional
// pre-computed artifacts (groups / lowered functions / attached source code).
struct CompilationContext {
  CompilationContext() = default;
  CompilationContext(const std::shared_ptr<Graph>& graph,
                     const std::shared_ptr<Scope>& scope,
                     const Target& target)
      : graph(graph), scope(scope), target(target) {}
  // Source code to substitute for the generated device module; empty means
  // "use the generated code" (see ApplySourceCode below).
  std::string attached_source_code = "";
  // Compile options.
  bool with_instantiate_variables = false;
  bool with_buffer_handle_instruction_inserted = false;
  bool remove_unused_variables = true;
  // Compile stage, full compile by default.
  CompilationStage stage = CompilationStage::DEFAULT;
  // Compile target.
  Target target;
  // Computation graph.
  std::shared_ptr<Graph> graph;
  // Variable scope
  std::shared_ptr<Scope> scope;
  // Fetch var ids in cinn and the corresponding var nodes will not be fused
  // so as to get the result.
  std::unordered_set<std::string> fetch_var_ids;
  // Map dst reuse var to the src var sharing buffer
  absl::flat_hash_map<std::string, std::string> reuse_vars_map;
  // Nodes group, it may come from the result of op fusion or graph tuning.
  // Nodes in a group will be built into an Instruction.
  std::vector<std::shared_ptr<Graph::Group>> groups;
  // Corresponding lowered functions of above grouped nodes,
  // if it is empty then graph_compiler will generate for them.
  std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
  // CUDA stream.
  void* stream = nullptr;
  // Set attached source code, if code is not empty, these codes will replace
  // the device_module code after SplitCudaAndHostModule.
  void ApplySourceCode(const std::string& code);
  // Apply results of auto-tune to compile.
  // Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
  // results are applied.
  void ApplyTuningResult(const auto_schedule::TuningResult& tuning_result);
};
class GraphCompiler;
class CompilationResult {
friend class GraphCompiler;
public:
void InitCompilationResult(int group_size);
// Setters
void SetStatus(int idx, const CompilationStatus& status);
void SetMessage(int idx, const std::string& message);
void SetLoweredFuncs(int idx, const std::vector<ir::LoweredFunc>& funcs);
void SetSourceCode(int idx, const std::string& source_code);
void SetSourcePtx(int idx, const std::string& source_ptx);
void SetInstruction(int idx, std::unique_ptr<Instruction> instruction);
void SetRuntimeProgram(std::unique_ptr<Program> runtime_program);
// Getters
bool IsSuccess() const;