Commit 01a10755 authored by yuguo-Jack

2.5.2-dtk24.04

parent 63eb0da5
add_subdirectory(pir)
core_gather_headers()
gather_srcs(
......@@ -12,22 +13,21 @@ gather_srcs(
program.cc
parallel_compiler.cc
graph_compiler.cc
graph_compiler_util.cc
graph.cc
node.cc
pass.cc
op_strategy.cc
op_lowering.cc
op_lowering_util.cc
op_lowering_impl.cc
accuracy_checker.cc
visualize_helper.cc)
visualize_helper.cc
compile_error.cc)
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
# TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could
# not found under CINN_ONLY mode
if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi
pd_dialect)
cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi
cinn_dialect)
cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp)
endif()
if(WITH_CUDA)
......@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore)
#cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
DEPS cinncore)
......@@ -21,7 +21,7 @@
#include <cuda_runtime.h>
#endif
DECLARE_int64(cinn_self_check_accuracy_num);
PD_DECLARE_int64(cinn_self_check_accuracy_num);
namespace cinn {
namespace hlir {
......
......@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
DECLARE_string(cinn_self_check_accuracy);
PD_DECLARE_string(cinn_self_check_accuracy);
namespace cinn {
namespace hlir {
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/utils/enum_string.h"
namespace cinn {
namespace hlir {
namespace framework {
std::string CompileErrorHandler::GeneralErrorMessage() const {
std::ostringstream os;
os << "[CompileError] An error occurred during compilation with the error "
"code: "
<< utils::Enum2String(status_) << std::endl;
os << "(at " << file_ << " : " << line_ << ")" << std::endl;
os << indent_str_ << "[Error info] " << this->err_msg_ << std::endl;
return os.str();
}
std::string CompileErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << indent_str_ << "[Detail info] " << detail_info_ << std::endl;
return os.str();
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace cinn {
namespace hlir {
namespace framework {
/**
* This handler is used to deal with errors raised during the compilation process
*/
class CompileErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param status the compilation status this error corresponds to
* \param err_msg the short error message
* \param detail_info extra information reported at the kDetailed error level
* \param file the source file where the error was raised
* \param line the source line where the error was raised
*/
explicit CompileErrorHandler(const CompilationStatus& status,
const std::string& err_msg,
const std::string& detail_info,
const char* file,
int line)
: status_(status),
err_msg_(err_msg),
detail_info_(detail_info),
file_(file),
line_(line) {}
/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
std::string GeneralErrorMessage() const;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string DetailedErrorMessage() const;
CompilationStatus Status() const { return status_; }
private:
CompilationStatus status_;
std::string err_msg_;
std::string detail_info_;
const char* file_;
int line_;
};
} // namespace framework
} // namespace hlir
} // namespace cinn
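// Illustrative usage sketch (not part of this commit): raising and printing
// a compile error with the handler declared above. The status value comes
// from CompilationStatus in graph_compiler_util.h; the messages and the
// helper name are made up.
#include <glog/logging.h>
#include "paddle/cinn/hlir/framework/compile_error.h"

void ReportLoweringFailureDemo() {
  using cinn::hlir::framework::CompileErrorHandler;
  using cinn::hlir::framework::CompilationStatus;
  CompileErrorHandler handler(CompilationStatus::LOWERING_FAIL,
                              "failed to lower group 0",
                              "no CINNStrategy registered for the op",
                              __FILE__,
                              __LINE__);
  // Short form: status code plus file/line location.
  LOG(ERROR) << handler.GeneralErrorMessage();
  // Long form: the general message plus the [Detail info] section.
  LOG(ERROR) << handler.DetailedErrorMessage();
}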
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include <string>
#include <unordered_map>
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
namespace cinn {
namespace hlir {
namespace framework {
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
const hlir::framework::Program& program) {
::ir::IrContext* ctx = ::ir::IrContext::Instance();
ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
auto ir_program = std::make_unique<::ir::Program>(ctx);
std::string jit_op_name = dialect::JitKernelOp::name();
::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);
auto& instrs = program.GetRunInstructions();
for (auto& instr : instrs) {
std::unordered_map<std::string, ::ir::Attribute> op_attrs{
{dialect::JitKernelOp::kAttrName,
::ir::PointerAttribute::get(ctx, instr.get())},
};
::ir::Operation* cinn_op =
::ir::Operation::Create({}, op_attrs, {}, op_info);
ir_program->block()->push_back(cinn_op);
}
return ir_program;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
namespace ir {
class Program;
} // namespace ir
namespace cinn {
namespace hlir {
namespace framework {
class Program;
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
const hlir::framework::Program& program);
} // namespace framework
} // namespace hlir
} // namespace cinn
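// Illustrative sketch (not part of this commit): wrapping a compiled program
// in the runtime dialect via the entry point declared above. `program` is
// assumed to be an already built hlir::framework::Program; the helper name
// is made up.
#include <glog/logging.h>
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"

std::unique_ptr<::ir::Program> WrapAsRuntimeDialectDemo(
    const cinn::hlir::framework::Program& program) {
  // One dialect::JitKernelOp is created per run-instruction; the raw
  // Instruction* travels in the op's kAttrName PointerAttribute and is
  // unpacked again at execution time.
  auto ir_program = cinn::hlir::framework::ConvertToRuntimeDialect(program);
  VLOG(3) << "Wrapped " << program.GetRunInstructions().size()
          << " instructions into JitKernelOps";
  return ir_program;
}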
......@@ -18,10 +18,14 @@
#include <sstream>
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#endif
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_fusion_groups_graphviz_dir);
PD_DECLARE_string(cinn_fusion_groups_graphviz_dir);
namespace cinn {
namespace hlir {
......@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph(
}
// Dump debug info for each group
LOG(INFO) << "Dump graph debug info to: "
<< FLAGS_cinn_fusion_groups_graphviz_dir;
VLOG(4) << "Dump graph debug info to: "
<< FLAGS_cinn_fusion_groups_graphviz_dir;
const auto& groups = RemoveAccCheckGroups(origin_groups);
const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
for (int idx = 0; idx < groups.size(); ++idx) {
// Create fusion_group_x folder
int device_id = 0;
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);
#endif
auto group_path =
utils::StringFormat("%s/fusion_group_%d",
utils::StringFormat("%s/device_%d/fusion_group_%d",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
device_id,
idx);
if (!MakeDirectory(group_path,
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
......@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups(
return dot_vec;
}
std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() const {
std::unordered_set<NodeData*> group_inputs;
// count every node's input data
......@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
return group_inputs;
}
std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() {
std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() const {
std::unordered_set<NodeData*> group_outputs;
for (auto node : this->output_nodes) {
......
......@@ -26,6 +26,7 @@
#include "paddle/cinn/hlir/framework/node.h"
namespace cinn {
namespace hlir {
namespace framework {
......@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph {
std::vector<std::vector<Node*>> groups;
struct Group {
Group() = default;
Group(const Group&) = delete;
Group(Group&&) = delete;
explicit Group(const Graph* graph) : graph_(graph) {}
......@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph {
}
};
std::vector<Node*> CollectNodes() {
std::vector<Node*> CollectNodes() const {
if (fused_sub_groups.size()) {
std::vector<Node*> tmp_nodes;
for (auto& group : fused_sub_groups) {
......@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph {
return node_set;
}
std::unordered_set<NodeData*> GetInputNodeDatas();
std::unordered_set<NodeData*> GetOutputNodeDatas();
std::unordered_set<NodeData*> GetInputNodeDatas() const;
std::unordered_set<NodeData*> GetOutputNodeDatas() const;
std::string GetFuncName() { return "fn_" + group_id + unique_id; }
......
......@@ -29,8 +29,11 @@
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
namespace cinn {
namespace hlir {
namespace framework {
......@@ -40,90 +43,124 @@ using cinn::common::float16;
std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
GraphCompiler::CompileOptions options;
options.attached_code = code;
options.with_instantiate_variables = true;
auto&& result = Build(options);
return std::move(result.runtime_program);
}
compilation_context_.ApplySourceCode(code);
compilation_context_.with_instantiate_variables = true;
void GraphCompiler::CompileOptions::Apply(
const auto_schedule::TuningResult& tuning_result) {
// assign options with TuningResult directly
groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
lowered_funcs.assign(tuning_result.function_groups.begin(),
tuning_result.function_groups.end());
auto&& result = Build(&compilation_context_);
return result.RuntimeProgram();
}
GraphCompiler::CompilationResult GraphCompiler::Build(
const GraphCompiler::CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids,
void* stream) {
CompilationResult GraphCompiler::Build(CompilationContext* context) {
Context::Global().ResetNameId();
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_
: fetch_var_ids);
context->graph->VisualizeGroupedGraph(context->fetch_var_ids);
if (options.with_instantiate_variables) {
InstantiateVariables();
if (context->with_instantiate_variables) {
InstantiateVariables(context);
}
VLOG(2) << "Compile With Parallel Compiler!";
utils::RecordEvent("GraphCompiler CompileResult",
utils::EventType::kOrdinary);
ParallelCompiler::CompileOptions option;
option.lowered_funcs = options.lowered_funcs;
parallel_compiler_ =
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
auto result = (*parallel_compiler_.get())();
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
// Dump compilation result
backends::CompilationInfoDumper dumper(result);
if (context->stage != CompilationStage::DEFAULT || !result.IsSuccess()) {
return result;
}
if (options.remove_unused_variables) {
RemoveInvalidVariables(result.instructions);
if (context->remove_unused_variables) {
RemoveInvalidVariables(context, result.RuntimeInstructions());
}
if (options.with_buffer_handle_instruction_inserted) {
if (context->with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&result.instructions);
InsertBufferHandlers(context, &result.instructions_);
}
VLOG(2) << "Compile With Parallel Compiler Done!";
GraphCompiler::CompilationResult compilation_result;
compilation_result.runtime_program.reset(
new Program(scope_, std::move(result.instructions)));
return compilation_result;
result.SetRuntimeProgram(std::make_unique<Program>(
context->scope, std::move(result.instructions_)));
return result;
}
CompilationResult GraphCompiler::Lowering() {
return Lowering(&compilation_context_);
}
void GraphCompiler::InstantiateVariables() {
CompilationResult GraphCompiler::Lowering(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just lowering!";
context->stage = CompilationStage::LOWERING;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::CodegenAndJit() {
return CodegenAndJit(&compilation_context_);
}
CompilationResult GraphCompiler::CodegenAndJit(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just codegen and jit!";
context->stage = CompilationStage::CODEGEN_AND_JIT;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::BuildInstruction() {
return BuildInstruction(&compilation_context_);
}
CompilationResult GraphCompiler::BuildInstruction(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just build instruction!";
context->stage = CompilationStage::BUILD_INSTRUCTION;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
void GraphCompiler::InstantiateVariables(CompilationContext* context) {
VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary);
// All variables reside in context->scope, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
for (auto& name : context->scope->var_names()) {
auto* var =
context->scope->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) {
auto src_var_name = reuse_vars_map_.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name);
if (context->reuse_vars_map.count(name)) {
auto src_var_name = context->reuse_vars_map.at(name);
auto* src_var = context->scope->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer());
} else {
tensor->mutable_data(target_, tensor->type());
tensor->mutable_data(context->target, tensor->type());
}
}
}
void GraphCompiler::RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions) {
// mark all variables as invalid initially
utils::RecordEvent("GraphCompiler RemoveInvalidVariables",
utils::EventType::kOrdinary);
std::unordered_set<std::string> invalid_variables;
auto var_names = scope_->var_names();
auto var_names = context->scope->var_names();
invalid_variables.reserve(var_names.size());
std::transform(
var_names.begin(),
......@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
<< " invalid variables to be removed from scope";
std::for_each(invalid_variables.begin(),
invalid_variables.end(),
[this](const std::string& var_name) {
scope_->EraseVar(var_name);
[context](const std::string& var_name) {
context->scope->EraseVar(var_name);
VLOG(3) << "Variable(" << var_name << ") is erased";
});
}
......@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
}
void GraphCompiler::InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions) {
utils::RecordEvent("GraphCompiler InsertBufferHandlers",
utils::EventType::kOrdinary);
......@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "malloc_buffer_instruction_" + std::to_string(step);
auto malloc_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(),
context->scope.get(),
malloc_var_names,
std::vector<std::string>({}),
function_name);
......@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "free_buffer_instruction_" + std::to_string(step);
auto free_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(),
context->scope.get(),
std::vector<std::string>({}),
free_var_names,
function_name);
......@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
poly::StageMap stages = C.back();
std::string func_name_prefix = "fn_";
auto funcs = lang::LowerVec(func_name_prefix + node_id,
stages,
all_arg_tensors,
{},
{},
nullptr,
target,
true);
ast_gen_ius::TensorGroup tensor_group =
ast_gen_ius::ConvertStageMapToTensorGroup(stages);
auto funcs = lang::LowerToAstVec(
func_name_prefix + node_id, all_arg_tensors, &tensor_group, target);
VLOG(4) << "Lower op: " << node_id << ", get " << funcs.size()
<< " LoweredFunc:\n";
for (auto fun : funcs) {
VLOG(4) << fun;
}
std::vector<common::CINNValue> schedule_inputs;
for (int i = 0; i < C.size() - 1; ++i) {
......@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
#endif
auto temp_buffers = lang::GetTempBuffers(
all_arg_tensors, stages, funcs_after_schedule[i]->body);
all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);
funcs_after_schedule[i]->temp_bufs = temp_buffers;
funcs_after_schedule[i] =
ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name,
......
......@@ -28,6 +28,7 @@
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
......@@ -46,48 +47,41 @@ namespace framework {
*/
class GraphCompiler final {
public:
GraphCompiler(Target target,
const std::shared_ptr<Scope>& scope,
const std::shared_ptr<Graph>& graph)
: target_(std::move(target)), scope_(scope), graph_(graph) {}
struct CompilationResult {
std::unique_ptr<Program> runtime_program;
};
struct CompileOptions {
std::string attached_code = "";
bool with_instantiate_variables = false;
bool with_buffer_handle_instruction_inserted = false;
bool remove_unused_variables = true;
// nodes group, it may come from the result of op fusion or graph tuning.
// nodes in a group will be built into an Instruction
std::vector<std::shared_ptr<Graph::Group>> groups;
// corresponding LoweredFuncs of above grouped nodes,
// if it is empty then graph_compiler will generate for them
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// apply results of auto-tune to compile
void Apply(const auto_schedule::TuningResult& tuning_result);
};
GraphCompiler(CompilationContext context) : compilation_context_(context) {}
// Compile with a packing option and result, to be extended easily.
CompilationResult Build(const CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids = {},
void* stream = nullptr);
CompilationResult Build(CompilationContext* context);
std::unique_ptr<Program> Build(const std::string& code = "");
const std::shared_ptr<Scope>& GetScope() const { return scope_; }
CompilationResult Lowering();
CompilationResult Lowering(CompilationContext* context);
CompilationResult CodegenAndJit();
CompilationResult CodegenAndJit(CompilationContext* context);
CompilationResult BuildInstruction();
CompilationResult BuildInstruction(CompilationContext* context);
const std::shared_ptr<Scope>& GetScope() const {
return compilation_context_.scope;
}
CompilationContext& GetCompilationContext() { return compilation_context_; }
void SetCompilationContext(const CompilationContext& context) {
compilation_context_ = context;
}
private:
// instantiate all variables on compile time
void InstantiateVariables();
void InstantiateVariables(CompilationContext* context);
// some variables are eliminated by optimization passes (such as OpFusion);
// we can filter them out according to the arguments of the built instructions
// and erase them from the scope to avoid unnecessary buffer allocation
void RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions);
// find the first and last instruction where a variable is used, and mark the
......@@ -102,21 +96,14 @@ class GraphCompiler final {
// first used in the next instruction, and insert a buffer free instruction
// applied to variables once no instruction will use them anymore
void InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions);
private:
// parallel compiler
std::shared_ptr<ParallelCompiler> parallel_compiler_;
Target target_;
std::shared_ptr<Graph> graph_;
std::shared_ptr<Scope> scope_;
// fetch var ids in cinn and the corresponding var nodes will not be fused so
// as to get the result
std::unordered_set<std::string> fetch_var_ids_;
// map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map_;
CompilationContext compilation_context_;
CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler);
};
......
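// Illustrative sketch (not part of this commit): the refactored entry point.
// All former GraphCompiler::CompileOptions fields now live in
// CompilationContext, as exercised by the updated tests below; the helper
// name is made up.
#include <glog/logging.h>
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"

std::unique_ptr<cinn::hlir::framework::Program> CompileGraphDemo(
    const std::shared_ptr<cinn::hlir::framework::Graph>& graph,
    const std::shared_ptr<cinn::hlir::framework::Scope>& scope,
    const cinn::common::Target& target) {
  using namespace cinn::hlir::framework;  // NOLINT
  CompilationContext context(graph, scope, target);
  context.with_instantiate_variables = true;
  GraphCompiler gc(context);
  CompilationResult result = gc.Build(&context);
  CHECK(result.IsSuccess()) << result.Message();
  return result.RuntimeProgram();
}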
......@@ -19,6 +19,7 @@
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
ASSERT_EQ(scope->var_names().size(), 6);
EXPECT_NE(scope->FindVar(c->id), nullptr);
GraphCompiler gc(target, scope, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto runtime_program = gc.Build();
ASSERT_EQ(scope->var_names().size(), 3);
EXPECT_EQ(scope->FindVar(c->id), nullptr);
......@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
GraphCompiler gc_disable(target, scope, graph);
GraphCompiler::CompileOptions options;
CompilationContext context_disable(graph, scope, target);
GraphCompiler gc_disable(context_disable);
// disable with_buffer_handle_instruction_inserted: only 1 instruction
auto runtime_program_disable = gc_disable.Build(options).runtime_program;
auto runtime_program_disable =
gc_disable.Build(&context_disable).RuntimeProgram();
ASSERT_EQ(runtime_program_disable->size(), 1);
const auto& computation_instr_disable =
runtime_program_disable->GetRunInstructions().front();
......@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
// enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
// malloc instruction(a, b, d), 2nd -> the real computation
// instruction(add + relu) and 3rd -> free instruction
GraphCompiler gc_enable(target, scope, graph);
options.with_buffer_handle_instruction_inserted = true;
auto runtime_program_enable = gc_enable.Build(options).runtime_program;
CompilationContext context_enable(graph, scope, target);
context_enable.with_buffer_handle_instruction_inserted = true;
GraphCompiler gc_enable(context_enable);
auto runtime_program_enable =
gc_enable.Build(&context_enable).RuntimeProgram();
const auto& instructions = runtime_program_enable->GetRunInstructions();
ASSERT_EQ(instructions.size(), 3);
......@@ -193,7 +198,8 @@ void RunCublas(
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
auto scope = BuildScope(target, graph);
GraphCompiler gc(target, scope, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto exe_program = gc.Build();
auto data_a = scope->GetTensor("A");
......@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
RunCublas(64, 128, 128, true, true);
}
TEST(GraphCompilerTest, TestLowering) {
frontend::NetBuilder builder("test_lowering_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.Lowering();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestCodegenAndJit) {
frontend::NetBuilder builder("test_codegen_and_jit_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.CodegenAndJit();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestBuildInstruction) {
frontend::NetBuilder builder("test_build_instruction_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.BuildInstruction();
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
#endif
} // namespace framework
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace cinn {
namespace hlir {
namespace framework {
void CompilationContext::ApplyTuningResult(
const auto_schedule::TuningResult& tuning_result) {
// assign context fields with the TuningResult directly
groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
lowered_funcs.assign(tuning_result.function_groups.begin(),
tuning_result.function_groups.end());
}
void CompilationContext::ApplySourceCode(const std::string& code) {
attached_source_code = code;
}
void CompilationResult::InitCompilationResult(int group_size) {
size_ = group_size;
status_.resize(group_size, CompilationStatus::SUCCESS);
messages_.resize(group_size);
for (int idx = 0; idx < group_size; ++idx) {
messages_[idx] =
"Group Idx: " + std::to_string(idx) + ", Compile Success.\n";
}
lowered_funcs_.resize(group_size, std::nullopt);
source_codes_.resize(group_size, std::nullopt);
source_ptxs_.resize(group_size, std::nullopt);
instructions_.resize(group_size);
}
void CompilationResult::SetStatus(int idx, const CompilationStatus& status) {
if (idx < status_.size()) {
status_[idx] = status;
}
}
void CompilationResult::SetMessage(int idx, const std::string& message) {
if (idx < messages_.size()) {
messages_[idx] = message;
}
}
void CompilationResult::SetLoweredFuncs(
int idx, const std::vector<ir::LoweredFunc>& funcs) {
if (idx < lowered_funcs_.size()) {
lowered_funcs_[idx] = funcs;
}
}
void CompilationResult::SetSourceCode(int idx, const std::string& source_code) {
if (idx < source_codes_.size()) {
source_codes_[idx] = source_code;
}
}
void CompilationResult::SetSourcePtx(int idx, const std::string& source_ptx) {
if (idx < source_ptxs_.size()) {
source_ptxs_[idx] = source_ptx;
}
}
void CompilationResult::SetInstruction(
int idx, std::unique_ptr<Instruction> instruction) {
if (idx < instructions_.size()) {
instructions_[idx] = std::move(instruction);
}
}
void CompilationResult::SetRuntimeProgram(
std::unique_ptr<Program> runtime_program) {
runtime_program_ = std::move(runtime_program);
}
bool CompilationResult::IsSuccess() const {
for (const CompilationStatus& s : status_) {
if (s != CompilationStatus::SUCCESS) {
return false;
}
}
return true;
}
CompilationStatus CompilationResult::Status() const {
CompilationStatus worst_status = CompilationStatus::SUCCESS;
// Smaller enum values denote worse statuses, so the overall status is the
// minimum across all groups.
for (const CompilationStatus& s : status_) {
if (s < worst_status) {
worst_status = s;
}
}
return worst_status;
}
CompilationStatus CompilationResult::Status(int idx) const {
if (idx >= status_.size()) {
return CompilationStatus::UNKNOWN_FAIL;
}
return status_[idx];
}
std::string CompilationResult::Message() const {
std::string res;
for (int idx = 0; idx < messages_.size(); ++idx) {
res += messages_[idx];
}
return res;
}
std::string CompilationResult::Message(int idx) const {
if (idx >= messages_.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group("
<< lowered_funcs_.size() << ").";
CINN_THROW(ss.str());
}
return messages_[idx];
}
std::vector<std::vector<ir::LoweredFunc>> CompilationResult::LoweredFuncs()
const {
std::vector<std::vector<ir::LoweredFunc>> res(lowered_funcs_.size());
for (int idx = 0; idx < lowered_funcs_.size(); ++idx) {
if (lowered_funcs_[idx].has_value()) {
res[idx] = lowered_funcs_[idx].value();
} else {
std::stringstream ss;
ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the lower "
"process.\n"
<< Message();
CINN_THROW(ss.str());
}
}
return res;
}
std::vector<ir::LoweredFunc> CompilationResult::LoweredFuncs(int idx) const {
if (idx >= lowered_funcs_.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group("
<< lowered_funcs_.size() << ").";
CINN_THROW(ss.str());
}
if (!lowered_funcs_[idx].has_value()) {
std::stringstream ss;
ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the lower process.\n"
<< Message();
CINN_THROW(ss.str());
}
return lowered_funcs_[idx].value();
}
std::vector<std::string> CompilationResult::SourceCodes() const {
std::vector<std::string> res(source_codes_.size());
for (int idx = 0; idx < source_codes_.size(); ++idx) {
if (source_codes_[idx].has_value()) {
res[idx] = source_codes_[idx].value();
} else {
std::stringstream ss;
ss << "Source Code of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the codegen "
"process.\n"
<< Message();
CINN_THROW(ss.str());
}
}
return res;
}
std::string CompilationResult::SourceCode(int idx) const {
if (idx >= source_codes_.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group("
<< lowered_funcs_.size() << ").";
CINN_THROW(ss.str());
}
if (!source_codes_[idx].has_value()) {
std::stringstream ss;
ss << "Source Code of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the codegen "
"process.\n"
<< Message();
CINN_THROW(ss.str());
}
return source_codes_[idx].value();
}
std::vector<std::string> CompilationResult::SourcePtxs() const {
std::vector<std::string> res(source_ptxs_.size());
for (int idx = 0; idx < source_ptxs_.size(); ++idx) {
if (source_ptxs_[idx].has_value()) {
res[idx] = source_ptxs_[idx].value();
} else {
std::stringstream ss;
ss << "Source PTX of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the nvrtc compile "
"process.\n"
<< Message();
CINN_THROW(ss.str());
}
}
return res;
}
std::string CompilationResult::SourcePtx(int idx) const {
if (idx >= source_ptxs_.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group("
<< lowered_funcs_.size() << ").";
CINN_THROW(ss.str());
}
if (!source_ptxs_[idx].has_value()) {
std::stringstream ss;
ss << "Source PTX of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the nvrtc compile "
"process.\n"
<< Message();
CINN_THROW(ss.str());
}
return source_ptxs_[idx].value();
}
const std::vector<std::unique_ptr<Instruction>>&
CompilationResult::RuntimeInstructions() const {
if (runtime_program_ != nullptr) {
return runtime_program_->GetRunInstructions();
}
for (int idx = 0; idx < instructions_.size(); ++idx) {
if (instructions_[idx] == nullptr) {
std::stringstream ss;
ss << "Instruction of group[" << idx << "] is not generated.\n"
<< "Some errors may have occurred during or before the build "
"instruction process.\n"
<< Message();
CINN_THROW(ss.str());
}
}
return instructions_;
}
const std::unique_ptr<Instruction>& CompilationResult::RuntimeInstruction(
int idx) const {
const std::vector<std::unique_ptr<Instruction>>& insts =
runtime_program_ ? runtime_program_->GetRunInstructions() : instructions_;
if (idx >= insts.size()) {
std::stringstream ss;
ss << "The index(" << idx
<< ") is expected to be less than the size of group(" << insts.size()
<< ").";
CINN_THROW(ss.str());
}
return insts[idx];
}
std::unique_ptr<Program> CompilationResult::RuntimeProgram() {
if (runtime_program_ == nullptr) {
std::stringstream ss;
ss << "Runtime program is not generated.\n"
<< "Some errors may have occurred during the compilation process.\n"
<< Message();
CINN_THROW(ss.str());
}
return std::move(runtime_program_);
}
} // namespace framework
} // namespace hlir
} // namespace cinn
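// Illustrative producer-side sketch (hypothetical): in this commit the
// parallel compiler is what actually fills a CompilationResult, roughly
// like this; the helper name and messages are made up.
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"

cinn::hlir::framework::CompilationResult MakeResultForTwoGroupsDemo() {
  using cinn::hlir::framework::CompilationResult;
  using cinn::hlir::framework::CompilationStatus;
  CompilationResult result;
  // Pre-size every per-group vector and default every status to SUCCESS.
  result.InitCompilationResult(/*group_size=*/2);
  // Record a lowering failure for group 1; group 0 keeps its defaults.
  result.SetStatus(1, CompilationStatus::LOWERING_FAIL);
  result.SetMessage(1, "Group Idx: 1, Compile Failed in lowering.\n");
  return result;
}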
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/ir/lowered_func.h"
namespace cinn {
namespace hlir {
namespace framework {
// An enum class used to control the compilation stage.
enum class CompilationStage {
// Fully compiled by default, the following compilation result can be
// obtained: lowered_function, source_code, source_ptx, instruction and
// runtime_program.
DEFAULT = 0,
// Just do lowering; only lowered_function can be obtained from the
// compilation result.
LOWERING = 1,
// Stop after codegen and jit; lowered_function, source_code and source_ptx
// can be obtained from the compilation result.
CODEGEN_AND_JIT = 2,
// Stop after building instructions; lowered_function, source_code,
// source_ptx and instruction can be obtained from the compilation result.
BUILD_INSTRUCTION = 3,
};
// An enum class used to represent the compilation status.
enum class CompilationStatus {
// An unknown error occurred during compilation.
UNKNOWN_FAIL = 0,
// An error occurred during lowering.
LOWERING_FAIL = 1,
// An error occurred during codegen and jit.
CODEGEN_JIT_FAIL = 2,
// An error occurred during build instruction.
INSTUCTION_FAIL = 3,
// An error occurred during build runtime program.
PROGRAM_FAIL = 4,
// Compile successfully.
SUCCESS = 5,
};
struct CompilationContext {
CompilationContext() = default;
CompilationContext(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Scope>& scope,
const Target& target)
: graph(graph), scope(scope), target(target) {}
std::string attached_source_code = "";
// Compile options.
bool with_instantiate_variables = false;
bool with_buffer_handle_instruction_inserted = false;
bool remove_unused_variables = true;
// Compile stage, full compile by default.
CompilationStage stage = CompilationStage::DEFAULT;
// Compile target.
Target target;
// Computation graph.
std::shared_ptr<Graph> graph;
// Variable scope
std::shared_ptr<Scope> scope;
// Fetch var ids in CINN; the corresponding var nodes will not be fused,
// so that their results can be fetched.
std::unordered_set<std::string> fetch_var_ids;
// Map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map;
// Nodes group, it may come from the result of op fusion or graph tuning.
// Nodes in a group will be built into an Instruction.
std::vector<std::shared_ptr<Graph::Group>> groups;
// Corresponding lowered functions of above grouped nodes,
// if it is empty then graph_compiler will generate for them.
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// CUDA stream.
void* stream = nullptr;
// Set the attached source code; if it is not empty, it will replace the
// device_module code after SplitCudaAndHostModule.
void ApplySourceCode(const std::string& code);
// Apply results of auto-tune to compile.
// Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
// results are applied.
void ApplyTuningResult(const auto_schedule::TuningResult& tuning_result);
};
class GraphCompiler;
class CompilationResult {
friend class GraphCompiler;
public:
void InitCompilationResult(int group_size);
// Setters
void SetStatus(int idx, const CompilationStatus& status);
void SetMessage(int idx, const std::string& message);
void SetLoweredFuncs(int idx, const std::vector<ir::LoweredFunc>& funcs);
void SetSourceCode(int idx, const std::string& source_code);
void SetSourcePtx(int idx, const std::string& source_ptx);
void SetInstruction(int idx, std::unique_ptr<Instruction> instruction);
void SetRuntimeProgram(std::unique_ptr<Program> runtime_program);
// Getters
bool IsSuccess() const;
int Size() const { return size_; }
CompilationStatus Status() const;
CompilationStatus Status(int idx) const;
std::string Message() const;
std::string Message(int idx) const;
std::vector<std::vector<ir::LoweredFunc>> LoweredFuncs() const;
std::vector<ir::LoweredFunc> LoweredFuncs(int idx) const;
std::vector<std::string> SourceCodes() const;
std::string SourceCode(int idx) const;
std::vector<std::string> SourcePtxs() const;
std::string SourcePtx(int idx) const;
const std::vector<std::unique_ptr<Instruction>>& RuntimeInstructions() const;
const std::unique_ptr<Instruction>& RuntimeInstruction(int idx) const;
std::unique_ptr<Program> RuntimeProgram();
private:
std::vector<CompilationStatus> status_;
std::vector<std::string> messages_;
std::vector<std::optional<std::vector<ir::LoweredFunc>>> lowered_funcs_;
std::vector<std::optional<std::string>> source_codes_;
std::vector<std::optional<std::string>> source_ptxs_;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::unique_ptr<Program> runtime_program_;
int size_;
};
} // namespace framework
} // namespace hlir
} // namespace cinn
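// Illustrative consumer-side sketch (not part of this commit): running only
// the lowering stage and inspecting per-group results. graph/scope/target
// are assumed to be prepared as in the tests above; the helper name is made
// up.
#include <glog/logging.h>
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"

void InspectLoweredFuncsDemo(
    const std::shared_ptr<cinn::hlir::framework::Graph>& graph,
    const std::shared_ptr<cinn::hlir::framework::Scope>& scope,
    const cinn::common::Target& target) {
  using namespace cinn::hlir::framework;  // NOLINT
  CompilationContext context(graph, scope, target);
  GraphCompiler gc(context);
  // Lowering() sets context->stage = CompilationStage::LOWERING internally.
  CompilationResult result = gc.Lowering(&context);
  if (result.Status() != CompilationStatus::SUCCESS) {
    LOG(ERROR) << result.Message();
    return;
  }
  for (int i = 0; i < result.Size(); ++i) {
    LOG(INFO) << "group " << i << ": " << result.LoweredFuncs(i).size()
              << " lowered funcs";
  }
}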
......@@ -20,7 +20,7 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
DECLARE_string(cinn_fusion_groups_graphviz_dir);
PD_DECLARE_string(cinn_fusion_groups_graphviz_dir);
namespace cinn {
namespace hlir {
......
......@@ -22,8 +22,8 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/profiler.h"
DECLARE_bool(cinn_sync_run);
DECLARE_string(cinn_self_check_accuracy);
PD_DECLARE_bool(cinn_sync_run);
PD_DECLARE_string(cinn_self_check_accuracy);
namespace cinn {
namespace hlir {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
#include <absl/types/variant.h>
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/utils/attribute_util.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_type.h"
namespace cinn {
namespace hlir {
namespace framework {
const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};
// TODO(Aurelius84): Need abstract this logic to implement Proxy for
// the co-existence with GraphCompiler.
std::unique_ptr<Program> NewIRCompiler::Build() {
m_builder_.Clear();
// NOTE(Aurelius84): Currently only support each op for one group
std::vector<std::vector<::ir::Operation*>> groups;
for (auto it = program_.block()->begin(); it != program_.block()->end();
++it) {
groups.push_back({*it});
}
VLOG(4) << "Groups size: " << groups.size();
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(GetOpFunc(*groups[i][0], i));
}
for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}
compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");
auto instructions = BuildInstructions(groups);
// TODO(Aurelius84): Instantiate all tensors on compile-time, which is
// controlled by 'options.with_instantiate_variables' in GraphCompiler.
// Moreover, it's better to implement InsertBufferHandlers() logic
// to automatically insert Malloc and Free instructions.
for (auto& name : scope_->var_names()) {
std::string var_name({name.data(), name.size()});
VLOG(4) << "Instantiate " << var_name << " on compile-time";
auto* var = scope_->Var<Tensor>(var_name);
auto& tensor = absl::get<Tensor>(*var);
tensor->mutable_data(target_, tensor->type());
}
return std::make_unique<Program>(scope_, std::move(instructions));
}
std::vector<ir::LoweredFunc> NewIRCompiler::GetOpFunc(const ::ir::Operation& op,
int idx) {
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
auto op_name = op.name();
VLOG(4) << "GetOpFunc for op: " << op_name;
// step 1: Deal with Operands
for (int i = 0; i < op.num_operands(); ++i) {
auto in_value = op.operand_source(i);
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std::string input_id = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(in_value));
auto type_info =
in_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto in_shape = phi::vectorize<int>(type_info.dims());
auto dtype = type_info.dtype();
ir::Tensor temp = lang::CreatePlaceHolder(
in_shape, utils::ConvertIRType(dtype), input_id);
inputs.push_back(temp);
cinn_inputs.push_back(common::CINNValue(temp));
}
for (auto out_name : OpGetOutputNames(op)) {
cinn_inputs.push_back(common::CINNValue(out_name));
}
VLOG(4) << "inputs.size(): " << inputs.size();
// step 2: Deal with OpResult
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
for (int i = 0; i < op.num_results(); ++i) {
auto out_value = op.result(i);
auto type_info =
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
out_types.push_back(utils::ConvertIRType(type_info.dtype()));
auto out_shape = phi::vectorize<int>(type_info.dims());
out_shapes.push_back(std::move(out_shape));
}
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs;
{
VLOG(4) << "op.attributes():" << op.attributes().size();
auto attrs = utils::ConvertAttributes(op.attributes());
node_attrs.node_name = CompatibleInfo::OP_NAMES.at(op_name);
node_attrs.attr_store = std::move(attrs);
}
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
// NOTE(Aurelius84): Do we need to replace all hlir::framework Operators
// with ::ir::Program?
const hlir::framework::Operator* cinn_op =
Operator::Get(CompatibleInfo::OP_NAMES.at(op_name));
auto impl = OpStrategy::SelectImpl(
strategy[cinn_op](node_attrs, inputs, out_types, out_shapes, target_));
common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs});
poly::StageMap stages = C.back();
// make sure all the tensors are in the stages before the schedule launches.
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
stages->InsertLazily(temp.as_tensor_ref());
}
C = impl->fschedule(C);
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
// check whether the tensor has a buffer.
if ((!temp.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) &&
!stages[temp.as_tensor_ref()]->inlined()) {
inputs.push_back(temp.as_tensor_ref());
}
}
auto func = lang::LowerVec(
GenOpFuncName(op, idx), stages, inputs, {}, {}, nullptr, target_);
return func;
}
void NewIRCompiler::ProcessFunction(
const std::vector<ir::LoweredFunc>& lowered_funcs) {
for (auto&& func : lowered_funcs) {
for (auto&& arg : func->args) {
std::string arg_name = arg.name();
if (arg_name[0] == '_') arg_name = arg_name.substr(1);
auto* var = scope_->FindVar(arg_name);
// For argument buffer not in scope, create it.
if (!var && arg.is_buffer()) {
auto* new_var = scope_->Var<Tensor>(arg_name);
auto& tensor = absl::get<Tensor>(*new_var);
std::vector<Shape::dim_t> shape;
for (auto& shape_dim : arg.buffer_arg()->shape) {
CHECK(shape_dim.is_constant());
shape.push_back(static_cast<int>(shape_dim.get_constant()));
}
tensor->Resize(Shape{shape});
tensor->set_type(arg.buffer_arg()->dtype);
}
}
m_builder_.AddFunction(func);
}
}
std::vector<std::unique_ptr<Instruction>> NewIRCompiler::BuildInstructions(
const std::vector<std::vector<::ir::Operation*>>& groups) {
std::vector<std::unique_ptr<Instruction>> instructions;
for (int idx = 0; idx < groups.size(); ++idx) {
// TODO(Aurelius84): only support single op in groups
auto& op = *groups[idx][0];
auto instr_name = op.name();
auto instr =
std::unique_ptr<Instruction>(new Instruction(target_,
scope_.get(),
OpGetInputNames(op),
OpGetOutputNames(op),
instr_name));
auto& op_func_name = GenOpFuncName(op, idx);
auto* fn_ptr = compiler_->Lookup(op_func_name);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), op_func_name);
// Some instructions, such as reduce, generate more than one kernel,
// so try to find the remaining kernels if they exist.
// SetSubKernels(instr.get(), op_func_name);
instr->Finalize();
instructions.push_back(std::move(instr));
}
return instructions;
}
const std::string& NewIRCompiler::GenOpFuncName(const ::ir::Operation& op,
int idx) {
// TODO(Aurelius84): the '.' in op names like pd.xxx will raise a compiler
// error; need a more elegant way to generate the function name.
std::string op_name = op.name().substr(3) + "_" + std::to_string(idx);
std::string func_name = Context::Global().NewName("fn_" + op_name);
func_names_.try_emplace(op_name, func_name);
return func_names_.at(op_name);
}
std::vector<std::string> NewIRCompiler::OpGetInputNames(
const ::ir::Operation& op) {
std::vector<std::string> names;
std::unordered_set<std::string> repeat;
for (int i = 0; i < op.num_operands(); ++i) {
auto value = op.operand_source(i);
std::string name = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(value));
if (repeat.count(name)) {
continue;
}
repeat.insert(name);
names.push_back(name);
}
return names;
}
std::vector<std::string> NewIRCompiler::OpGetOutputNames(
const ::ir::Operation& op) {
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
std::string name = CompatibleInfo::kOutputPrefix +
std::to_string(std::hash<::ir::Value>()(value));
names.push_back(std::move(name));
}
return names;
}
std::shared_ptr<Scope> BuildScope(const Target& target,
const ::ir::Program& program) {
std::unordered_set<::ir::Value> visited;
auto scope = std::make_shared<Scope>();
auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
if (visited.count(value) > 0) return;
visited.emplace(value);
std::string name =
name_prefix + std::to_string(std::hash<::ir::Value>()(value));
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var);
// NOTE: can be replaced with phi::vectorize?
std::vector<Shape::dim_t> shape;
for (auto i = 0; i < type_info.dims().size(); ++i) {
shape.push_back(Shape::dim_t(type_info.dims()[i]));
}
tensor->Resize(Shape{shape});
tensor->set_type(utils::ConvertIRType(type_info.dtype()));
};
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand_source(i);
create_var(CompatibleInfo::kInputPrefix, in_value);
}
for (auto i = 0; i < (*it)->num_results(); ++i) {
auto out_value = (*it)->result(i);
create_var(CompatibleInfo::kOutputPrefix, out_value);
}
}
return scope;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/macros.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace cinn {
namespace hlir {
namespace framework {
struct CompatibleInfo {
static constexpr char* kInputPrefix = "input_";
static constexpr char* kOutputPrefix = "output_";
// TODO(Aurelius): Need to add name mapping logic in REGISTER_CINN_OP
// macros or attempt to unify Op name with Paddle and CINN.
static const std::unordered_map<std::string, std::string> OP_NAMES;
};
// TODO(Aurelius84): Need abstract this logic to implement Proxy for
// the co-existence with GraphCompiler.
class NewIRCompiler final {
public:
NewIRCompiler(const ::ir::Program& prog,
const Target& target,
const std::shared_ptr<Scope>& scope)
: program_(prog),
m_builder_("NewIR", target),
target_(target),
scope_(scope) {}
std::unique_ptr<Program> Build();
std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);
void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);
std::vector<std::unique_ptr<Instruction>> BuildInstructions(
const std::vector<std::vector<::ir::Operation*>>& groups);
protected:
const std::string& GenOpFuncName(const ::ir::Operation& op, int idx);
std::vector<std::string> OpGetInputNames(const ::ir::Operation& op);
std::vector<std::string> OpGetOutputNames(const ::ir::Operation& op);
private:
CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);
const ::ir::Program& program_;
ir::Module::Builder m_builder_;
std::unique_ptr<backends::Compiler> compiler_{nullptr};
Target target_;
std::shared_ptr<Scope> scope_;
std::unordered_map<std::string, std::string> func_names_;
};
std::shared_ptr<Scope> BuildScope(const Target&, const ::ir::Program&);
} // namespace framework
} // namespace hlir
} // namespace cinn
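// Illustrative end-to-end sketch (not part of this commit): compiling a
// ::ir::Program with the NewIRCompiler declared above. Building `prog`
// itself is elided; the NVGPU target mirrors the existing tests and the
// helper name is made up.
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"

std::unique_ptr<cinn::hlir::framework::Program> CompileNewIRDemo(
    const ::ir::Program& prog) {
  auto target = cinn::common::DefaultNVGPUTarget();
  // BuildScope names each variable input_<hash>/output_<hash> after the
  // ::ir::Value it belongs to, matching OpGetInputNames/OpGetOutputNames.
  auto scope = cinn::hlir::framework::BuildScope(target, prog);
  cinn::hlir::framework::NewIRCompiler compiler(prog, target, scope);
  // Build() lowers each single-op group, jits the module, and returns a
  // runtime Program with one Instruction per group.
  return compiler.Build();
}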
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/op_lowering_util.h"
#include "paddle/cinn/hlir/op/external_api_registry.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
DECLARE_bool(cinn_use_cuda_vectorize);
namespace cinn {
namespace hlir {
namespace framework {
using common::bfloat16;
using common::float16;
using framework::Node;
using framework::NodeData;
using framework::OpPatternKind;
using framework::shape_t;
using framework::StrategyFunction;
using common::Type;
using cinn::hlir::op::ExternalApiRegistry;
OpLowerer::OpLowerer(
const absl::flat_hash_map<std::string, Type>& type_dict,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const Target& target)
: type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {}
std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule) {
VLOG(3) << "Lowering Group : " << group->group_id
<< " , Op Pattern : " << group->op_pattern_kind;
group->input_names.clear();
group->output_names.clear();
switch (group->op_pattern_kind) {
case framework::kElementWise:
case framework::kBroadcast:
case framework::kInjective:
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::ElementwiseScheduleDetermineFunction);
case framework::kReduction:
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::ReduceScheduleDetermineFunction);
case framework::kOutFusible:
LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
case framework::kNonFusible:
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::NonFusibleScheduleDetermineFunction);
default:
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
}
bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) {
return true;
}
bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) {
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
return op_pattern_dict[node->op()] == framework::kReduction;
}
bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; }
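// Illustrative sketch (not part of this commit): driving the dispatch in
// Lower() above. The dicts, target and group are assumed to be prepared by
// the caller; the helper name is hypothetical.
std::vector<ir::LoweredFunc> LowerOneGroupDemo(
    const absl::flat_hash_map<std::string, Type>& type_dict,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const Target& target,
    const GroupPtr& group) {
  OpLowerer lowerer(type_dict, shape_dict, target);
  // Lower() clears group->input_names/output_names, then dispatches on
  // group->op_pattern_kind to the matching schedule-determine function.
  return lowerer.Lower(group,
                       /*apply_op_schedule=*/true,
                       /*apply_group_schedule=*/true);
}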
std::vector<ir::LoweredFunc> OpLowerer::LowerGroup(
const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
ScheduleDetermineFunction schedule_determine_func) {
// 1.Do compute, lower and schedule for each op.
VLOG(3) << "group->fused_sub_groups.size() is : "
<< group->fused_sub_groups.size();
std::vector<Node*> nodes = group->CollectNodes();
if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") {
return LowerCustomCall(group);
}
std::vector<ir::Tensor> group_func_arg_tensors;
std::unordered_map<std::string, ir::Tensor> tensor_map;
bool do_op_schedule = apply_group_schedule || apply_op_schedule;
std::vector<ir::Expr> func_bodies = LowerOps(nodes,
do_op_schedule,
schedule_determine_func,
&group_func_arg_tensors,
&tensor_map);
// 2.Do group schedule.
ir::ModuleExpr mod_expr(func_bodies);
ir::IRSchedule ir_sch(mod_expr);
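  // Merge the per-op func bodies into one expression so the group schedule
  // below can operate across op boundaries.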
ir_sch.MergeExprs();
VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
if (apply_group_schedule) {
DoGroupSchedule(ir_sch, group, tensor_map);
VLOG(3) << "After group schedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
}
// 3.Do post-processing,
// including preparing function args and temporary variables,
// applying low-level optimization passes, etc.
return PostProcess(
group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors);
}
std::vector<ir::LoweredFunc> OpLowerer::LowerCustomCall(const GroupPtr& group) {
std::vector<Node*> nodes = group->CollectNodes();
CHECK_EQ(nodes.size(), 1);
Node* node = nodes[0];
std::vector<ir::Tensor> op_func_arg_tensors;
std::unordered_map<std::string, ir::Tensor> tensor_map;
for (auto& node_data : GetInputNodeData(node)) {
CHECK(node_data);
ir::Tensor tensor;
if (!tensor_map.count(node_data->id())) {
tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
// record tensor.
tensor_map[node_data->id()] = tensor;
// input name.
group->input_names.push_back(node_data->id());
} else {
tensor = tensor_map[node_data->id()];
}
op_func_arg_tensors.push_back(tensor);
}
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
group->output_names.push_back(node_data->id());
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
}
auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
node->attrs, op_func_arg_tensors, out_types, out_shapes, target_));
std::string external_api;
if (node->attrs.attr_store.count("custom_call")) {
external_api =
absl::get<std::string>(node->attrs.attr_store.at("custom_call"));
} else {
external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_);
}
std::vector<common::CINNValue> compute_args = {
common::CINNValue(group->GetFuncName()), common::CINNValue(external_api)};
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{compute_args});
CHECK_EQ(pack.size(), 1UL);
  // Reset input names, since extern api input args can't be deduplicated.
group->input_names.clear();
for (auto& inode : node->inlinks_in_order()) {
group->input_names.push_back(inode->source()->as<NodeData>()->id());
}
return {pack[0].operator ir::Expr().as_lowered_func_ref()};
}
std::vector<ir::LoweredFunc> OpLowerer::PostProcess(
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::IRSchedule* ir_sch,
std::vector<ir::Tensor>* group_func_arg_tensors) {
// 1.Prepare function args
group->input_names.clear();
std::vector<ir::Argument> group_func_args;
std::unordered_set<std::string> arg_name_set;
for (auto& arg_tensor : *group_func_arg_tensors) {
// input node data name.
group->input_names.push_back(arg_tensor->name);
// input args
group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
arg_name_set.insert(arg_tensor->buffer->name);
}
group->output_names.clear();
for (auto& node : group->output_nodes) {
    // collect all output tensors.
for (auto node_data : GetAllNodeData(node)) {
std::string output_node_data_name = node_data->id();
group->output_names.push_back(output_node_data_name);
// CHECK(tensor_map.count(output_node_data_name)) << "Can't find output
// tensor " << output_node_data_name;
if (tensor_map.count(output_node_data_name) == 0) {
continue;
}
auto tensor = tensor_map.at(output_node_data_name);
if (arg_name_set.count(tensor->buffer->name) != 0) {
continue;
}
// output arg tensors
group_func_arg_tensors->push_back(tensor);
// output args
group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
arg_name_set.insert(tensor->buffer->name);
}
}
if (!done_op_schedule) {
std::unordered_set<std::string> args_set;
for (auto arg : group_func_args) {
args_set.insert(arg.name());
}
for (auto& tensor_pair : tensor_map) {
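      // Buffer argument names carry a leading underscore relative to the
      // tensor name, hence the "_" prefix in the lookup below.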
if (args_set.count("_" + tensor_pair.second->name)) {
continue;
}
group_func_arg_tensors->push_back(tensor_pair.second);
// use the underlying tensor name to be consistent with the argument name
// in the lowered function
group->output_names.push_back(tensor_pair.second->name);
group_func_args.emplace_back(tensor_pair.second->buffer,
ir::Argument::IO::kOutput);
}
}
auto func_body = ir_sch->GetModule().GetExprs().at(0);
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
group_func_args,
ir_sch->GetModule().GetExprs().at(0),
temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
return {func};
}
std::vector<ir::Expr> OpLowerer::LowerOps(
const std::vector<Node*>& nodes,
bool apply_op_schedule,
ScheduleDetermineFunction schedule_determine_func,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::unordered_map<std::string, ir::Tensor>* tensor_map) {
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
std::vector<Expr> func_bodies;
for (Node* node : nodes) {
// 1.Select Op impl
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
std::vector<NodeData*> node_datas = GetAllNodeData(node);
for (const auto& node_data : node_datas) {
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
}
std::vector<ir::Tensor> op_func_arg_tensors =
std::move(CollectInputTensor(node,
this->type_dict_,
this->shape_dict_,
group_func_arg_tensors,
tensor_map));
auto op_impl =
OpStrategy::SelectImpl(strategy[node->op()](node->attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
// 2.Perform the lower process of Op
std::vector<ir::LoweredFunc> funcs =
DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors);
if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
// 3.Perform the schedule of Op
func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs));
} else {
for (const ir::LoweredFunc& func : funcs) {
func_bodies.push_back(func->body);
}
}
}
return func_bodies;
}
std::vector<ir::LoweredFunc> OpLowerer::DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
Node* node,
std::unordered_map<std::string, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors) {
VLOG(4) << "Do lower with Compute, op: " << node->op()->name;
std::vector<common::CINNValue> cinn_inputs;
for (const ir::Tensor& tensor : *op_func_arg_tensors) {
cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
}
// set tensor name = node data name
std::vector<NodeData*> node_datas = GetAllNodeData(node);
for (const NodeData* node_data : node_datas) {
cinn_inputs.push_back(common::CINNValue(node_data->id()));
}
// 1.Do compute
common::CINNValuePack pack =
op_impl->fcompute(common::CINNValuePack{cinn_inputs});
poly::StageMap tmp_stages = pack.back();
std::string post = "";
for (int idx = 0; idx < pack.size() - 1; ++idx) {
Expr expr = pack[idx];
// Insert the output tensor defined by Compute into the tensor_map
if (pack.size() - 1 > node_datas.size()) {
      // Some nodes may output multiple temp tensors in their Compute
      // definition, but have only one output node_data in the graph, so we
      // use id, id + "_0", id + "_1", ... as the keys.
(*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
post = "_" + std::to_string(idx);
} else {
      // If the number of output tensors defined by Compute is less than or
      // equal to the number of output node_data on the graph, there is a
      // one-to-one correspondence, and any redundant output node_data stays
      // empty.
(*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
}
// Insert output tensors into function arg
if (!expr.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
}
}
// 2.Do lower
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
tmp_stages,
*op_func_arg_tensors,
{},
{},
nullptr,
this->target_,
true);
VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
<< " LoweredFunc:\n";
op_func_arg_tensors->clear();
for (int idx = 0; idx < pack.size() - 1; ++idx) {
CHECK(pack[idx].is_tensor());
op_func_arg_tensors->push_back(
pack[idx].operator ir::Expr().as_tensor_ref());
}
return funcs;
}
ir::Expr OpLowerer::DoOpSchedule(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const std::vector<ir::Tensor>& op_func_arg_tensors,
const std::vector<ir::LoweredFunc>& lowered_funcs) {
VLOG(4) << "Do op schedule";
std::vector<common::CINNValue> schedule_inputs;
// 1.Collect tensors
for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) {
schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor));
}
// 2.Collect bodies to be scheduled
for (const ir::LoweredFunc& func : lowered_funcs) {
schedule_inputs.push_back(common::CINNValue(func->body));
}
// 3.Do schedule on AST
common::CINNValuePack expr_pack =
op_impl->fschedule(common::CINNValuePack{schedule_inputs});
VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr();
return expr_pack[0].operator ir::Expr();
}
// group schedule
ir::Expr OpLowerer::DoGroupSchedule(
ir::IRSchedule& ir_sch,
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
// topological order.
auto nodes_set = group->NodeSet();
auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_);
auto nodes_in_order =
BFSTopologicalOrderWithPriority(group, v_consumers, this->shape_dict_);
// find reducer.
std::unordered_set<Node*> nodes_inline;
auto greducer = FindGlobalReducer(nodes_in_order);
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
// do schedule
for (auto node : nodes_in_order) {
VLOG(4) << "Try FUSION " << node->op()->name;
// consumers.
auto consumers = GetConsumersInSet(node, nodes_set);
const Node* reducer =
greducer ? FindNearestReducer(node, nodes_set) : greducer;
if (!reducer && greducer) {
reducer =
v_consumers.count(node) ? v_consumers.find(node)->second : reducer;
if (reducer && op_pattern_dict[reducer->op()] != framework::kReduction) {
reducer = nullptr;
}
}
auto masters = GetMasters(node, nodes_inline, nodes_set);
// node can be inline.
if (CanbeInline(node,
consumers,
reducer,
masters,
group,
nodes_set,
this->shape_dict_)) {
VLOG(3) << "Before compute inline, ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
auto block = ir_sch.GetBlock(GetNodeData(node)->id());
ir::ComputeInlineChecker checker(ir_sch, block);
if (!checker.Check()) {
checker.BuildDataDependency();
continue;
}
      // if a global reduce node exists.
if (greducer) {
auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
if (op_pattern_dict[node->op()] == framework::kElementWise) {
ir_sch.FlattenLoops(loops, true);
} else {
ir_sch.FlattenLoops(loops, false);
}
}
ir_sch.ComputeInline(block);
nodes_inline.insert(node);
VLOG(3) << "After compute inline, ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
continue;
}
// find master to computeat.
auto master = GetMasterToComputeAt(node,
nodes_in_order,
nodes_inline,
nodes_set,
v_consumers,
this->shape_dict_);
// assign to reducer/master loop.
if (reducer) {
VLOG(3) << "Before assign node " << node->id()
<< " into vertical link reducer " << reducer->id() << ", ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
// if node is vertical with reduce, loop assign reducer.
LoopAssignReduce(
ir_sch, node, reducer, this->target_, tensor_map, this->shape_dict_);
} else if (greducer) {
auto greducer_out_shape =
this->shape_dict_.at(greducer->outlinks_in_order()[0]->sink()->id());
auto node_out_shape =
this->shape_dict_.at(node->outlinks_in_order()[0]->sink()->id());
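      // Compare the total element counts (products of the shape extents) of
      // the global reducer's output and this node's output; only when they
      // differ is the node's loop reassigned to the reducer here.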
if (std::accumulate(greducer_out_shape.begin(),
greducer_out_shape.end(),
1,
std::multiplies<int>()) !=
std::accumulate(node_out_shape.begin(),
node_out_shape.end(),
1,
std::multiplies<int>())) {
LoopAssignReduce(ir_sch,
node,
greducer,
this->target_,
tensor_map,
this->shape_dict_);
} else {
VLOG(3) << "Before assign node " << node->id()
<< " into horizontal link reducer " << greducer->id()
<< ", ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
        // if node is horizontal with reduce or node is reduce, loop assign
        // master.
auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
if (op_pattern_dict[node->op()] == framework::kElementWise) {
ir_sch.FlattenLoops(loops, true);
} else if (op_pattern_dict[node->op()] != framework::kReduction) {
ir_sch.FlattenLoops(loops, false);
}
if (master && op_pattern_dict[node->op()] != framework::kReduction) {
auto master_loops = ir_sch.GetLoops(GetNodeData(master)->id());
std::vector<int> splits;
for (auto loop : master_loops) {
splits.push_back(loop.As<ir::For>()->extent.as_int32());
}
loops = ir_sch.GetLoops(GetNodeData(node)->id());
ir_sch.Split(loops[0], splits);
}
}
}
VLOG(3) << "Before loop fusion, ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
// do loop fuse.
LoopComputeAt(ir_sch,
node,
master ? master : nodes_in_order.front(),
group,
this->shape_dict_,
tensor_map);
VLOG(3) << "After loop fusion, ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
}
// do vectorize
auto all_blocks = ir_sch.GetAllBlocks();
VLOG(4) << "Size of blocks: " << all_blocks.size();
VLOG(4) << "Op Pattern : " << group->op_pattern_kind;
// only support first block?
auto block = all_blocks[0];
CHECK(block->as<ir::ScheduleBlockRealize>());
CHECK(block->as<ir::ScheduleBlockRealize>()
->schedule_block->as<ir::ScheduleBlock>());
auto is_tensor_block = true;
auto tensor_name = block->as<ir::ScheduleBlockRealize>()
->schedule_block->as<ir::ScheduleBlock>()
->name;
if (!tensor_map.count(tensor_name)) {
is_tensor_block = false;
}
if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block &&
(group->op_pattern_kind == framework::kElementWise ||
group->op_pattern_kind == framework::kInjective ||
group->op_pattern_kind == framework::kBroadcast)) {
// auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
auto loops = ir_sch.GetLoops(block);
VLOG(4) << "Op Pattern : " << loops.size();
if (loops.size() >= 1) {
VLOG(4) << "Before vectorize, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
auto loop_inner = loops.back();
int vector_width = 1;
auto psize = ir::GetLoopExtent(loop_inner);
// get dtype of vectorized var
auto dtype = this->type_dict_.at(tensor_name);
VLOG(4) << tensor_name << " dtype " << dtype;
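      // Pick the widest lane that evenly divides the loop extent: 8 only for
      // fp16/bf16, otherwise 4 or 2. E.g. a float16 loop of extent 24 is
      // split into (-1, 8) below and the inner loop vectorized by 8.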
if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) {
vector_width = 8;
} else if (psize % 4 == 0) {
vector_width = 4;
} else if (psize % 2 == 0) {
vector_width = 2;
}
if (vector_width > 1) {
auto splited = ir_sch.Split(loop_inner, {-1, vector_width});
splited[0].As<ir::For>()->set_bind_info(
loop_inner.As<ir::For>()->bind_info());
splited[1].As<ir::For>()->set_serial();
ir_sch.Vectorize(splited[1], vector_width);
}
VLOG(4) << "After vectorize, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
}
}
VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
SyncThreadWithShared(
ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
VLOG(4) << "After IRSchedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
return ir_sch.GetModule().GetExprs().at(0);
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,166 +13,78 @@
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/hlir/framework/op_lowering_impl.h"
#include "paddle/cinn/hlir/framework/op_lowering_impl_base.h"
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
#include "paddle/cinn/lang/packed_func.h"
// Fusion op lowering. There are four kinds of lowering functions:
// Elementwise/Broadcast/Injective, Reduce, OutEWiseFusable, and NonFusible.
// Elementwise/Broadcast/Injective ops share the same schedule;
// Reduce, OutEWiseFusable, and NonFusible each use a different schedule.
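// For example, a group whose op_pattern_kind is kElementWise takes the
// Elementwise policy, while a kReduction group takes the Reduce policy
// (see the switch on group->op_pattern_kind in the lowering implementation).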
#ifndef CINN_WITH_ONLY
#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h"
#endif
namespace cinn {
namespace hlir {
namespace framework {
using GroupPtr = std::shared_ptr<Graph::Group>;
using common::Target;
using GroupPtr = std::shared_ptr<hlir::framework::Graph::Group>;
class OpLowerer;
typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
template <typename T>
class OpLowerer {
public:
OpLowerer(const absl::flat_hash_map<std::string, Type>&,
const absl::flat_hash_map<std::string, shape_t>&,
const Target&);
explicit OpLowerer(OpLowererImplBase<T>* impl) { impl_.reset(impl); }
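  // Takes ownership of `impl`; all lowering calls below delegate to it.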
~OpLowerer() {}
/**
* @brief Lower a group to CINN IR.
* @param group The group to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param apply_group_schedule Whether to schedule at group level.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
std::vector<ir::LoweredFunc> Lower(const T& group,
bool apply_op_schedule = true,
bool apply_group_schedule = true);
bool apply_group_schedule = true,
bool apply_pass = true) {
return impl_->Lower(
group, apply_op_schedule, apply_group_schedule, apply_pass);
}
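  // A hedged summary: BucketLower produces one (symbolic predicate, lowered
  // func) pair per shape bucket, so a caller can pick the func whose
  // predicate holds for the actual shapes at runtime.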
std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>> BucketLower(
const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
return impl_->BucketLower(
group, apply_op_schedule, apply_group_schedule, apply_pass);
}
private:
/**
* @brief Lower a group to CINN IR.
* @param group The group to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param apply_group_schedule Whether to schedule at group level.
* @param schedule_determine_func Function used to determine which Ops to
* schedule.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> LowerGroup(
const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
ScheduleDetermineFunction schedule_determine_func);
/**
* @brief Lower a group composed of CustomCall Op.
* @param group The group to be lowered.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
/**
* @brief Post processing, including preparing function args and temporary
* variables, applying low-level optimization passes, etc.
* @param group The group to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param done_op_schedule Mark whether the Op level schedule has been
* applied.
* @param ir_sch The IRSchedule object of group.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @return The lowered funcs after the post processing.
*/
std::vector<ir::LoweredFunc> PostProcess(
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::IRSchedule* ir_sch,
std::vector<ir::Tensor>* group_func_arg_tensors);
/**
* @brief Lower an Op set to CINN IR.
* Compute, Lower and optional Schedule will be performed one by one
* for each Op.
* @param nodes The Op nodes to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param schedule_determine_func Function used to determine which Ops to
* schedule.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func bodies of Op set.
*/
std::vector<ir::Expr> LowerOps(
const std::vector<Node*>& nodes,
bool apply_op_schedule,
ScheduleDetermineFunction schedule_determine_func,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::unordered_map<std::string, ir::Tensor>* tensor_map);
/**
* @brief Lower an Op to CINN IR. The Compute and Lower processes will be
* called sequentially.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param node The Op node to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @return The lowered func of the Op node.
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
Node* node,
std::unordered_map<std::string, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);
/**
* @brief Apply schedule on an Op.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @param lowered_funcs The lowered funcs of an Op to be scheduled.
* @return The lowered func body after schedule of the Op.
*/
ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
const std::vector<ir::Tensor>& op_func_arg_tensors,
const std::vector<ir::LoweredFunc>& lowered_funcs);
/**
* @brief Apply schedule on a group.
* @param ir_sch The IRSchedule containing the entire group's lowered func
* bodies.
* @param group The group to be scheduled.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func body after schedule of the group.
*/
ir::Expr DoGroupSchedule(
ir::IRSchedule& ir_sch, // NOLINT
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);
// Functions used to determine which Ops to schedule at op level, define a
// policy for each type of group.
inline bool ReduceScheduleDetermineFunction(Node* node);
inline bool ElementwiseScheduleDetermineFunction(Node* node);
inline bool NonFusibleScheduleDetermineFunction(Node* node);
private:
Target target_;
const absl::flat_hash_map<std::string, Type>& type_dict_;
const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
  // function name prefix
const std::string func_name_prefix = "fn_";
std::shared_ptr<OpLowererImplBase<T>> impl_;
};
template <typename T = GroupPtr>
OpLowerer<T> CreateOpLowerer(const absl::flat_hash_map<std::string, Type>&,
const absl::flat_hash_map<std::string, shape_t>&,
const Target&);
template <>
inline OpLowerer<GroupPtr> CreateOpLowerer(
const absl::flat_hash_map<std::string, Type>& type_dict,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const Target& target) {
auto* impl_base = new OpLowererImpl(type_dict, shape_dict, target);
return OpLowerer<GroupPtr>(impl_base);
}
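// A minimal usage sketch (variable names are illustrative):
//   auto lowerer = CreateOpLowerer<GroupPtr>(type_dict, shape_dict, target);
//   std::vector<ir::LoweredFunc> funcs = lowerer.Lower(group);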
#ifndef CINN_WITH_ONLY
template <typename T = pir::GroupPtr>
OpLowerer<T> CreateOpLowerer(const Target&);
template <>
inline OpLowerer<pir::GroupPtr> CreateOpLowerer(const Target& target) {
auto* impl_base = new pir::OpLowererImpl(target);
return OpLowerer<pir::GroupPtr>(impl_base);
}
#endif
} // namespace framework
} // namespace hlir
} // namespace cinn