Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -27,6 +27,7 @@
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "test/cpp/cinn/program_builder.h"
......@@ -44,11 +45,11 @@ std::vector<TuneTask> CreateTasks(const frontend::Program& program,
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks.size(); ++i) {
tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
tasks[i].Initialize(shape_dict, dtype_dict, &op_lowerer);
task_registry->Regist(tasks[i].serialized_key,
ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
}
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
......@@ -46,16 +47,13 @@ TEST(MutateTileSize, Basic) {
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({A, B, C});
ast_gen_ius::TensorGroup tensor_group({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMutateTileSize_Basic",
stages,
lang::LowerToAstVec("TestMutateTileSize_Basic",
{A, B, C},
{},
{},
nullptr,
target,
true);
&tensor_group,
target);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Original Expr: ";
......@@ -65,7 +63,7 @@ TEST(MutateTileSize, Basic) {
// repeated.
utils::LinearRandomEngine::StateType rand_seed = 123;
ir::IRSchedule ir_schedule(module_expr, rand_seed);
ir::IRSchedule new_ir_schedule(ir_schedule);
ir::IRSchedule pir_schedule(ir_schedule);
// apply schedule
auto loops = ir_schedule.GetLoops("C");
......@@ -76,13 +74,13 @@ TEST(MutateTileSize, Basic) {
MutateTileSize mutator;
ir::ScheduleDesc sch_desc =
mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
sch_desc.Replay(&new_ir_schedule, true);
sch_desc.Replay(&pir_schedule, true);
VLOG(6) << "Expr before mutate tile size: \n"
<< ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n"
<< new_ir_schedule.GetModule().GetExprs()[0];
<< pir_schedule.GetModule().GetExprs()[0];
std::string target_new_ir = R"ROC({
std::string target_pir = R"ROC({
ScheduleBlock(root)
{
serial for (i_1, 0, 2)
......@@ -117,7 +115,7 @@ TEST(MutateTileSize, Basic) {
ss << exprs[0];
return ss.str();
};
ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir);
ASSERT_EQ(get_ir_str(&pir_schedule), target_pir);
std::vector<int> last_tile_factors = {2, 16};
for (int i = 0; i < 10; ++i) {
......
......@@ -40,7 +40,7 @@
#include "paddle/cinn/backends/cuda_util.h"
#endif
DECLARE_bool(auto_schedule_use_cost_model);
PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
......@@ -247,7 +247,7 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(
auto& optimized_funcs = result.functions;
auto& best_cost = result.cost;
// use initial lowered function as default result
optimized_funcs = optim::IRCopy(task_->lowered_funcs);
optimized_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
if (options.num_measure_trials ==
0) { // no need to measure and simply return the best searched
std::vector<MeasureInput> measure_candidates;
......@@ -347,7 +347,7 @@ std::vector<SearchState> TaskOptimizer::SearchOneRound(
CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size())
<< "RuntimeError: Expr size is not equal to LoweredFunc size in "
"TaskOptimizer";
auto init_funcs = optim::IRCopy(task_->lowered_funcs);
auto init_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
std::vector<ir::LoweredFunc> valid_funcs;
for (size_t j = 0; j < best_exprs.size(); ++j) {
auto updated_f =
......
......@@ -14,14 +14,13 @@
#pragma once
#include <gflags/gflags.h>
#include <mutex>
#include <string>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/utils/registry.h"
#include "paddle/utils/flags.h"
namespace cinn {
......@@ -64,7 +63,7 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
std::lock_guard<std::mutex> guard(registering_mutex);
if (fmap_.count(task_key) == 0) {
InitialTaskInfo* task_info =
new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
new InitialTaskInfo(task_key, ir::ir_utils::IRCopy(module_expr));
__REGISTER__(task_key, task_info);
}
}
......
......@@ -28,7 +28,7 @@
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/type_defs.h"
DECLARE_bool(auto_schedule_use_cost_model);
PD_DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
......@@ -45,11 +45,10 @@ std::vector<TuneTask> CreateTasks(hlir::framework::Graph* graph,
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
std::unique_ptr<hlir::framework::OpLowerer> op_lowerer =
std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
for (TuneTask& task : tasks) {
task.Initialize(shape_dict, dtype_dict, op_lowerer.get());
task.Initialize(shape_dict, dtype_dict, &op_lowerer);
VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
}
......
......@@ -34,7 +34,7 @@ void TuneTask::Initialize(
const absl::flat_hash_map<std::string, hlir::framework::shape_t>&
shape_dict,
const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict,
hlir::framework::OpLowerer* lower_handler) {
hlir::framework::OpLowerer<GroupPtr>* lower_handler) {
CHECK(lower_handler != nullptr) << "op_lowerer can't be nullptr";
op_lowerer = lower_handler;
......
......@@ -34,16 +34,17 @@ namespace cinn {
namespace auto_schedule {
class TuneTask {
using GroupPtr = hlir::framework::GroupPtr;
public:
TuneTask() = default;
explicit TuneTask(std::shared_ptr<hlir::framework::Graph::Group> group)
: subgraph(group) {}
explicit TuneTask(GroupPtr group) : subgraph(group) {}
// Initialize a task
void Initialize(
const absl::flat_hash_map<std::string, hlir::framework::shape_t>&
shape_dict,
const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict,
hlir::framework::OpLowerer* lower_handler);
hlir::framework::OpLowerer<GroupPtr>* lower_handler);
// Extract bodies in lowered_funcs() and return
std::vector<ir::Expr> GetLoweredFuncBodyExprs() const;
......@@ -51,7 +52,7 @@ class TuneTask {
// sub-graph (if an op won't be fused, it will be a Group with size=1).
std::shared_ptr<hlir::framework::Graph::Group> subgraph;
// Lower handler, Not owned
hlir::framework::OpLowerer* op_lowerer;
hlir::framework::OpLowerer<GroupPtr>* op_lowerer;
// target of this task
common::Target target;
// stores the initial (un-optimized) LoweredFuncs
......
......@@ -31,8 +31,8 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -75,7 +75,8 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
OpLowerer op_lowerer(dtype_dict, shape_dict, target);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
std::stringstream ss;
for (TuneTask& task : tasks) {
......@@ -187,7 +188,8 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) {
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
OpLowerer op_lowerer(dtype_dict, shape_dict, target);
OpLowerer op_lowerer(
new hlir::framework::OpLowererImpl(dtype_dict, shape_dict, target));
std::stringstream ss;
for (TuneTask& task : tasks) {
......@@ -291,7 +293,8 @@ TEST(TuneTask, SerializeToString) {
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
OpLowerer op_lowerer(dtype_dict, shape_dict, target);
OpLowerer op_lowerer(
new hlir::framework::OpLowererImpl(dtype_dict, shape_dict, target));
ASSERT_EQ(single_tasks.size(), 2UL);
for (auto&& task : single_tasks) {
task.Initialize(shape_dict, dtype_dict, &op_lowerer);
......
......@@ -25,7 +25,9 @@
#include "paddle/cinn/frontend/paddle_model_convertor.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/runtime/flags.h"
......@@ -42,7 +44,7 @@
* parameters for more detail.
*/
DEFINE_string(resnet50_model_dir,
PD_DEFINE_string(resnet50_model_dir,
"./ResNet50",
"the path to paddle model resnet50.");
// Flags that control which schedule tests will be run.
......@@ -52,15 +54,16 @@ DEFINE_string(resnet50_model_dir,
// auto schedule test, means options = 4 = "100" will run auto schedule test.
// The default value is -1, which means that this flag is disabled to set the
// options
DEFINE_int32(evaluate_knobs,
PD_DEFINE_int32(evaluate_knobs,
-1,
"the options to control which schedule tests will be run.");
DECLARE_double(cinn_infer_model_version);
PD_DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
......@@ -94,8 +97,8 @@ class PerformanceTester : public ::testing::Test {
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
VLOG(3) << "Build " << schedule_name << " program.";
auto scope = BuildScope(target_, graph);
auto graph_compiler =
std::make_unique<GraphCompiler>(target_, scope, graph);
CompilationContext context(graph, scope, target_);
auto graph_compiler = std::make_unique<GraphCompiler>(context);
auto runtime_program =
(this->*build_fn)(graph.get(), graph_compiler.get());
if (execute) {
......@@ -141,28 +144,27 @@ class PerformanceTester : public ::testing::Test {
absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
"infershape");
std::shared_ptr<hlir::framework::OpLowerer> op_lowerer =
std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
auto op_lowerer =
hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target_);
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
CompilationContext& context = graph_compiler->GetCompilationContext();
context.with_instantiate_variables = true;
if (graph->fusion_groups.empty()) {
hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
}
compile_options.groups = graph->fusion_groups;
context.groups = graph->fusion_groups;
for (auto group : graph->fusion_groups) {
compile_options.lowered_funcs.push_back(
op_lowerer->Lower(group,
context.lowered_funcs.push_back(
op_lowerer.Lower(group,
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false));
}
VLOG(3) << "===========================No Schedule LoweredFunc "
"Begin===========================";
for (const auto& funcvec : compile_options.lowered_funcs) {
for (const auto& funcvec : context.lowered_funcs) {
for (const auto& func : funcvec) {
VLOG(3) << func;
}
......@@ -170,7 +172,7 @@ class PerformanceTester : public ::testing::Test {
VLOG(3) << "===========================No Schedule LoweredFunc "
"End=============================";
return graph_compiler->Build(compile_options).runtime_program;
return graph_compiler->Build();
}
std::unique_ptr<hlir::framework::Program> BuildManualScheduleProgram(
......@@ -191,13 +193,13 @@ class PerformanceTester : public ::testing::Test {
tuner->Initialize(tuning_config, graph_compiler);
TuningResult tuning_result = tuner->Tune(tuning_options);
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
compile_options.Apply(tuning_result);
CompilationContext& context = graph_compiler->GetCompilationContext();
context.with_instantiate_variables = true;
context.ApplyTuningResult(tuning_result);
VLOG(3) << "===========================Auto Schedule LoweredFunc "
"Begin===========================";
for (const auto& funcvec : compile_options.lowered_funcs) {
for (const auto& funcvec : context.lowered_funcs) {
for (const auto& func : funcvec) {
VLOG(3) << func;
}
......@@ -205,7 +207,7 @@ class PerformanceTester : public ::testing::Test {
VLOG(3) << "===========================Auto Schedule LoweredFunc "
"End=============================";
return graph_compiler->Build(compile_options).runtime_program;
return graph_compiler->Build();
}
#ifdef CINN_WITH_CUDA
......
......@@ -23,13 +23,12 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/runtime/cpu/thread_backend.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
//! Root of the builtin code.
DECLARE_string(cinn_x86_builtin_code_root);
PD_DECLARE_string(cinn_x86_builtin_code_root);
namespace cinn {
namespace backends {
......@@ -39,7 +38,7 @@ using cinn::common::float16;
const char *kCKeywordRestrict = "__restrict__";
void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
ir::IrVerify(Expr(module));
ir::ir_utils::IrVerify(Expr(module));
if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader);
......@@ -286,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) {
void CodeGenC::Visit(const ir::IfThenElse *op) {
str_ += "if (";
IrPrinter::Visit(op->condition);
str_ += ") {\n";
str_ += ") ";
if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent();
str_ += "}";
if (op->false_case.defined()) {
str_ += " else {\n";
if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
str_ += " else ";
IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent();
str_ += "}";
}
}
void CodeGenC::Visit(const ir::Block *op) {
......@@ -645,7 +626,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
Expr func_body = ir::Block::Make(new_body);
optim::RemoveNestedBlock(&func_body);
optim::SimplifyBlocks(&func_body);
IrPrinter::Visit(func_body);
}
......@@ -766,6 +747,7 @@ void CodeGenC::Visit(const ir::ScheduleBlock *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) {
CINN_NOT_IMPLEMENTED
}
void CodeGenC::Visit(const ir::_Dim_ *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::IntrinsicOp *op) {
switch (op->getKind()) {
......
......@@ -14,19 +14,18 @@
#pragma once
#include <gflags/gflags.h>
#include <string>
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/packed_func.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
#include "paddle/utils/flags.h"
namespace cinn {
......
......@@ -19,6 +19,7 @@
#include <sstream>
#include <tuple>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/module.h"
......@@ -65,8 +66,10 @@ TEST(CodeGenC, module) {
target.os = Target::OS ::Linux;
Module::Builder builder("module1", target);
auto stages = CreateStages({A, B, C});
auto func = Lower("add1", stages, {A, B, C});
ast_gen_ius::TensorGroup tensor_group({A, B, C});
auto func = lang::LowerToAst("add1", {A, B, C}, &tensor_group);
LOG(INFO) << "Func to codegen: " << func << std::endl;
builder.AddFunction(func);
......@@ -74,7 +77,7 @@ TEST(CodeGenC, module) {
CodeGenC codegen(target);
codegen.SetInlineBuiltinCodes(false);
auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
std::cout << "codegen C:" << std::endl << out << std::endl;
LOG(INFO) << "codegen C:" << std::endl << out << std::endl;
std::string target_str = R"ROC(
#include <cinn_runtime.h>
......
......@@ -24,7 +24,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_nested_block.h"
namespace cinn {
namespace backends {
......@@ -57,7 +56,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, bool for_nvrtc) {
void CodeGenCUDA_Dev::Compile(const ir::Module &module,
const Outputs &outputs) {
ir::IrVerify(Expr(module));
ir::ir_utils::IrVerify(Expr(module));
CodeGenC::inline_builtin_codes_ = false;
if (!outputs.c_header_name.empty()) {
......@@ -91,7 +90,7 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
temp_buffers.end());
// prepare temp buffer alias
std::vector<Expr> buffer_alias;
auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) {
auto tensors = ir::ir_utils::CollectIRNodes(op->body, [&](const Expr *x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
temp_buffer_set.count(x->as_tensor()->buffer);
});
......@@ -141,7 +140,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
Expr func_body = ir::Block::Make(new_body);
optim::RemoveNestedBlock(&func_body);
optim::SimplifyBlocks(&func_body);
// Make sure that the function's body is wrapped by a block
if (!func_body.As<ir::Block>()) {
func_body = ir::Block::Make({func_body});
......
......@@ -20,9 +20,9 @@
#include "paddle/cinn/backends/codegen_c.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/packed_func.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
......
......@@ -30,8 +30,8 @@
#include "paddle/cinn/common/test_helper.h"
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/hlir/pe/schedule.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/utils/timer.h"
......
......@@ -182,5 +182,159 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(
return function;
}
// Lowers a CINN host-side lowered function into an LLVM function.
//
// Unlike LowerGPUKernelLauncher, the body is lowered directly via
// CodeGenLLVM::Visit; the kernel-launch calls inside it are handled by
// LowerCUDAKernelCall through the Visit(ir::Call*) override.
//
// Fix over the previous version: the `ll_function_args` vector that was
// collected with std::transform was never read afterwards — dead work,
// removed.
llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
  // Create the LLVM function declaration mirroring the CINN signature.
  auto* function_type = GenFunctionTypeFromCinnFunction(func, true);
  f_ = llvm::Function::Create(
      function_type, llvm::Function::ExternalLinkage, func->name, m_);
  f_->setCallingConv(llvm::CallingConv::C);
  f_->setHasUWTable();

  // Open the entry block and lower the function body into it.
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(
      /*Context=*/b_->getContext(),
      /*Name=*/"entry",
      /*Parent=*/f_,
      /*InsertBefore=*/nullptr);
  b_->SetInsertPoint(entry);
  CodeGenLLVM::Visit(&func->body);
  RetVoid();

  return f_;
}
// Lowers a CINN extern call (a CUDA kernel launch) into an LLVM call.
//
// The enclosing host wrapper `f_` carries (void* args, int32 num_args
// [, void* stream]); those arguments are resolved by name through
// `global_args`. Scalar read_args are materialized as LLVM constants.
// Two passes over read_args: first derive the callee's LLVM signature,
// then build the concrete argument values. Returns nullptr (no SSA value).
llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
  // Collect pointers to the current host function's LLVM arguments.
  std::vector<llvm::Value*> ll_function_args;
  std::transform(f_->arg_begin(),
                 f_->arg_end(),
                 std::back_inserter(ll_function_args),
                 [](auto& arg) { return std::addressof(arg); });
  // Wrapper layout: [0] = packed kernel args, [1] = arg count,
  // optional [2] = stream handle.
  auto* kernel_args = ll_function_args[0];
  auto* kernel_args_count = ll_function_args[1];
  llvm::Value* kernel_stream = nullptr;
  if (ll_function_args.size() == 3) {
    kernel_stream = ll_function_args[2];
    CHECK_EQ(kernel_stream->getType(), ll_void_p_ty());  // void* stream
  }
  CHECK_EQ(kernel_args->getType(), ll_void_p_ty());       // void* args
  CHECK_EQ(kernel_args_count->getType(), ll_int32_ty());  // int32
  // Well-known argument names -> wrapper LLVM values; var read_args
  // referring to these names resolve against this map below.
  // NOTE(review): kernel_stream may stay nullptr when the wrapper has only
  // two params — callers are then expected not to reference KERNEL_STREAM.
  std::unordered_map<std::string, llvm::Value*> global_args = {
      {KERNEL_ARGS, kernel_args},
      {KERNEL_ARGS_NUM, kernel_args_count},
      {KERNEL_STREAM, kernel_stream}};
  // Pass 1: derive the callee's parameter types from the read_args.
  auto ret_type = CinnTypeToLLVMType(Void(), m_);
  std::vector<llvm::Type*> args_type;
  for (auto r_arg : call_ir->read_args) {
    if (r_arg.is_var()) {
      // Vars: handles and strings lower to void*; int32 vars stay int32.
      if (r_arg.as_var()->type().is_cpp_handle() ||
          r_arg.as_var()->type().is_string()) {
        args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
      } else if (r_arg.as_var()->type().is_int(32)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
      } else {
        CINN_NOT_IMPLEMENTED;
      }
    } else {
      // Immediates: map each supported CINN scalar type to LLVM.
      if (r_arg.type().is_bool()) {
        args_type.push_back(CinnTypeToLLVMType(type_of<bool>(), m_));
      } else if (r_arg.type().is_uint(8)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<uint8_t>(), m_));
      } else if (r_arg.type().is_uint(16)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<uint16_t>(), m_));
      } else if (r_arg.type().is_uint(32)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<uint32_t>(), m_));
      } else if (r_arg.type().is_uint(64)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<uint64_t>(), m_));
      } else if (r_arg.type().is_int(8)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<int8_t>(), m_));
      } else if (r_arg.type().is_int(16)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<int16_t>(), m_));
      } else if (r_arg.type().is_int(32)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
      } else if (r_arg.type().is_int(64)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<int64_t>(), m_));
      } else if (r_arg.type().is_float(32)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<float>(), m_));
      } else if (r_arg.type().is_float(64)) {
        args_type.push_back(CinnTypeToLLVMType(type_of<double>(), m_));
      } else if (r_arg.type().is_bfloat16()) {
        args_type.push_back(CinnTypeToLLVMType(type_of<bfloat16>(), m_));
      } else if (r_arg.type().is_float16()) {
        args_type.push_back(CinnTypeToLLVMType(type_of<float16>(), m_));
      } else {
        CINN_NOT_IMPLEMENTED;
      }
    }
  }
  auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
  // Declare (or reuse) the callee symbol in the module.
  auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);
  // Pass 2: build the concrete call arguments.
  std::vector<llvm::Value*> call_args;
  for (auto& r_arg : call_ir->read_args) {
    if (r_arg.is_var()) {
      if (r_arg.as_var()->type().is_string()) {
        // String vars load from a module-level i8* global named "<var>_ptr_".
        auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_",
                                            b_->getInt8PtrTy());
        call_args.push_back(b_->CreateLoad(
            b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
      } else if (r_arg.as_var()->type().is_cpp_handle() ||
                 r_arg.as_var()->type().is_int(32)) {
        // Handle/int32 vars must be one of the well-known wrapper args.
        CHECK(global_args.count(r_arg.as_var()->name));
        call_args.push_back(global_args[r_arg.as_var()->name]);
      } else {
        CINN_NOT_IMPLEMENTED;
      }
    } else {
      // Immediates become LLVM constants of the matching width.
      if (r_arg.type().is_bool()) {
        call_args.push_back(b_->getInt1(r_arg.as_bool()));
      } else if (r_arg.type().is_int(8)) {
        call_args.push_back(b_->getInt8(r_arg.as_int8()));
      } else if (r_arg.type().is_int(16)) {
        call_args.push_back(b_->getInt16(r_arg.as_int16()));
      } else if (r_arg.type().is_int(32)) {
        call_args.push_back(b_->getInt32(r_arg.as_int32()));
      } else if (r_arg.type().is_int(64)) {
        call_args.push_back(b_->getInt64(r_arg.as_int64()));
      } else if (r_arg.type().is_uint(8)) {
        call_args.push_back(b_->getInt8(r_arg.as_uint8()));
      } else if (r_arg.type().is_uint(16)) {
        call_args.push_back(b_->getInt16(r_arg.as_uint16()));
      } else if (r_arg.type().is_uint(32)) {
        call_args.push_back(b_->getInt32(r_arg.as_uint32()));
      } else if (r_arg.type().is_uint(64)) {
        call_args.push_back(b_->getInt64(r_arg.as_uint64()));
      } else if (r_arg.type().is_float(32)) {
        call_args.push_back(llvm::ConstantFP::get(
            b_->getFloatTy(), llvm::APFloat(r_arg.as_float())));
      } else if (r_arg.type().is_float(64)) {
        call_args.push_back(llvm::ConstantFP::get(
            b_->getDoubleTy(), llvm::APFloat(r_arg.as_double())));
      } else if (r_arg.type().is_bfloat16()) {
        // bfloat16/float16 immediates are widened to float first.
        call_args.push_back(llvm::ConstantFP::get(
            b_->getBFloatTy(),
            llvm::APFloat(static_cast<float>(r_arg.as_bfloat16()))));
      } else if (r_arg.type().is_float16()) {
        call_args.push_back(llvm::ConstantFP::get(
            b_->getHalfTy(),
            llvm::APFloat(static_cast<float>(r_arg.as_float16()))));
      } else {
        CINN_NOT_IMPLEMENTED;
      }
    }
  }
  b_->CreateCall(call_func, call_args);
  return nullptr;
}
} // namespace backends
} // namespace cinn
......@@ -23,6 +23,8 @@
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
PD_DECLARE_bool(cinn_bucket_compile);
namespace cinn {
namespace backends {
......@@ -38,9 +40,16 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
using CodeGenLLVM::Visit;
llvm::Value *Visit(const ir::_LoweredFunc_ *func) override {
if (FLAGS_cinn_bucket_compile) {
return LowerHostFunc(func);
}
return LowerGPUKernelLauncher(func);
}
llvm::Value *Visit(const ir::Call *op) override {
return LowerCUDAKernelCall(op);
}
private:
/**
* Lower a CUDA kernel launcher.
......@@ -56,6 +65,10 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
*
*/
llvm::Value *LowerGPUKernelLauncher(const ir::_LoweredFunc_ *func);
llvm::Value *LowerHostFunc(const ir::_LoweredFunc_ *func);
llvm::Value *LowerCUDAKernelCall(const ir::Call *op);
};
} // namespace backends
......
......@@ -15,16 +15,106 @@
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
PD_DECLARE_bool(cinn_bucket_compile);
namespace cinn {
namespace backends {
// Splits one mixed module into a {host, device} pair. Bucket-strategy
// compilation (multiple predicated kernels behind one host stub) uses its
// dedicated visitor; the plain host/device split is the fallback.
std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
  if (FLAGS_cinn_bucket_compile) {
    detail::CollectBucketStrategyHostFunctionVisitor bucket_visitor(
        module->name);
    Expr module_expr(module);
    return bucket_visitor(&module_expr);
  }
  detail::CollectHostFunctionVisitor plain_visitor(module->name);
  Expr module_expr(module);
  return plain_visitor(&module_expr);
}
// Prints a predicate expression in an identifier-friendly mangled form:
// each binary operation is rendered as "_FPA_<lhs><OP><rhs>_BPA_", where the
// _FPA_/_BPA_ markers stand in for parentheses. The result is embedded into
// device kernel symbol names (see GenDeviceKernelName).
struct PredicatePrinter : public ir::IrPrinter {
  explicit PredicatePrinter(std::ostream &os) : ir::IrPrinter(os) {}

 private:
  // Arithmetic operators.
  void Visit(const ir::Add *x) { PrintBinaryOp("ADD", x); }
  void Visit(const ir::Sub *x) { PrintBinaryOp("SUB", x); }
  void Visit(const ir::Mul *x) { PrintBinaryOp("MUL", x); }
  void Visit(const ir::Div *x) { PrintBinaryOp("DIV", x); }
  void Visit(const ir::Mod *x) { PrintBinaryOp("MOD", x); }
  // Comparison operators.
  void Visit(const ir::EQ *x) { PrintBinaryOp("EQ", x); }
  void Visit(const ir::NE *x) { PrintBinaryOp("NE", x); }
  void Visit(const ir::LT *x) { PrintBinaryOp("LT", x); }
  void Visit(const ir::LE *x) { PrintBinaryOp("LE", x); }
  void Visit(const ir::GT *x) { PrintBinaryOp("GT", x); }
  void Visit(const ir::GE *x) { PrintBinaryOp("GE", x); }
  // Logical operators.
  void Visit(const ir::And *x) { PrintBinaryOp("AND", x); }
  void Visit(const ir::Or *x) { PrintBinaryOp("OR", x); }

  // Emits "_FPA_<a><op><b>_BPA_" for any binary node, recursing through the
  // base IrPrinter for the operands.
  template <typename IRN>
  void PrintBinaryOp(const std::string &op, const ir::BinaryOpNode<IRN> *x) {
    str_ += "_FPA_";
    ir::IrPrinter::Visit(x->a());
    str_ += op;
    ir::IrPrinter::Visit(x->b());
    str_ += "_BPA_";
  }
};
// Renders a predicate expression to its mangled string form via
// PredicatePrinter (used to build unique device kernel names).
std::string Predicate2String(ir::Expr predicate) {
  std::stringstream buffer;
  PredicatePrinter printer(buffer);
  printer.Print(predicate);
  return buffer.str();
}
// Builds the device kernel symbol for one bucket: the original function name
// with the predicate encoded in the middle, so every bucket of the same
// lowered function gets a distinct kernel.
std::string
detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
    const std::string &fn_name, ir::Expr predicate) {
  const std::string cond_str = Predicate2String(predicate);
  VLOG(3) << "predicate string: " << cond_str;
  std::string kernel_name = fn_name;
  kernel_name += "__COND_";
  kernel_name += cond_str;
  kernel_name += "__kernel";
  return kernel_name;
}
// Registers one (function, predicate) bucket: emits the renamed device
// kernel into the device module, and queues a predicate-guarded host-side
// launch call into buckets_ (assembled into the single host stub by
// Visit(_Module_)).
void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
    ir::Expr func, ir::Expr predicate) {
  ir::_LoweredFunc_ *func_node = func.as_lowered_func();
  CHECK(func_node);
  // Mark the axis info valid so the grid/block dims read below are defined
  // even when lowering left them unset.
  if (!func_node->cuda_axis_info.valid()) {
    func_node->cuda_axis_info.set_valid(true);
  }
  // Device side: add a renamed deep copy of the function.
  device_module_builder.AddFunctionWithoutOptim(
      CreateDeviceFunction(func, predicate).as_lowered_func_ref());
  // Host side: build the extern launch call. The kernel is referenced by
  // its mangled name; args/num/stream come from the visitor's shared vars.
  ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate),
                     type_of<std::string>());
  ir::Expr call_extern_api =
      ir::Call::Make(Void(),
                     runtime::intrinsic::call_cuda_kernel,
                     {kernel_ptr,
                      kernel_args_,
                      kernel_args_num_,
                      Expr(func_node->cuda_axis_info.grid_dim(0)),  // grid_x
                      Expr(func_node->cuda_axis_info.grid_dim(1)),  // grid_y
                      Expr(func_node->cuda_axis_info.grid_dim(2)),  // grid_z
                      Expr(func_node->cuda_axis_info.block_dim(0)),  // block_x
                      Expr(func_node->cuda_axis_info.block_dim(1)),  // block_y
                      Expr(func_node->cuda_axis_info.block_dim(2)),  // block_z
                      kernel_stream_},
                     {},
                     ir::CallType::Extern,
                     ir::FunctionRef(),
                     0);
  // Guard the launch with this bucket's predicate.
  buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
}
// Produces the device-side variant of a lowered function: a deep copy
// renamed with the predicate-encoded kernel name. The input expression is
// left untouched.
Expr detail::CollectBucketStrategyHostFunctionVisitor::CreateDeviceFunction(
    ir::Expr expr, ir::Expr predicate) {
  ir::Expr device_fn = ir::ir_utils::IRCopy(expr);
  auto *fn_node = device_fn.as_lowered_func();
  fn_node->name = GenDeviceKernelName(fn_node->name, predicate);
  return device_fn;
}
} // namespace backends
} // namespace cinn
......@@ -22,8 +22,8 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn {
namespace backends {
......@@ -57,7 +57,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
device_module_builder.Build());
}
private:
protected:
void Visit(const ir::_LoweredFunc_* op, Expr* expr) override {
if (op->body.As<ir::Call>()) {
host_module_builder.AddFunctionWithoutOptim(expr->as_lowered_func_ref());
......@@ -127,7 +127,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
}
Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) {
auto copied = optim::IRCopy(expr);
auto copied = ir::ir_utils::IRCopy(expr);
auto* lowered_func = copied.as_lowered_func();
lowered_func->name = GenDeviceKernelName(lowered_func->name);
return copied;
......@@ -137,11 +137,61 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
return fn + "_kernel";
}
private:
protected:
ir::Module::Builder host_module_builder;
ir::Module::Builder device_module_builder;
};
// Visitor for bucket-strategy compilation: every lowered function in the
// module comes with a predicate; each (function, predicate) pair becomes a
// device kernel, and a single host stub dispatches among them by evaluating
// the predicates at runtime.
//
// Fixes over the previous version: the loop index is size_t (the old `int`
// index triggered a signed/unsigned comparison against .size()), and an
// empty function list is rejected explicitly instead of indexing
// functions[0] out of bounds.
struct CollectBucketStrategyHostFunctionVisitor
    : public CollectHostFunctionVisitor {
  explicit CollectBucketStrategyHostFunctionVisitor(
      const std::string& module_name)
      : CollectHostFunctionVisitor(module_name),
        kernel_args_(KERNEL_ARGS, type_of<void*>()),
        kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
        kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}

  // Splits the module expression and returns {host module, device module}.
  std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
    ir::IRMutator<>::Visit(expr, expr);
    return std::make_tuple(host_module_builder.Build(),
                           device_module_builder.Build());
  }

 private:
  void Visit(const ir::_Module_* op, Expr* expr) {
    CHECK_EQ(op->functions.size(), op->predicates.size());
    CHECK(!op->functions.empty())
        << "CollectBucketStrategyHostFunctionVisitor: module has no "
           "functions to split";
    for (size_t i = 0; i < op->functions.size(); ++i) {
      ProcessLoweredFunc(op->functions[i], op->predicates[i]);
    }
    // The host stub takes the packed kernel args, their count, and the
    // stream; its body is the chain of predicate-guarded launches collected
    // in buckets_. It reuses the first function's name as the entry symbol.
    std::vector<ir::Argument> arguments = {
        ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
        ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
        ir::Argument(kernel_stream_, ir::Argument::IO::kOutput)};
    ir::Expr host_func =
        ir::_LoweredFunc_::Make(op->functions[0].as_lowered_func()->name,
                                arguments,
                                ir::Block::Make(buckets_),
                                {});
    host_module_builder.AddFunctionWithoutOptim(
        host_func.as_lowered_func_ref());
  }

  // Emits the device kernel and queues the guarded host launch for `func`.
  void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
  // Deep-copies `expr` and renames it with the predicate-encoded name.
  Expr CreateDeviceFunction(ir::Expr expr, ir::Expr predicate);
  inline std::string GenDeviceKernelName(const std::string& fn_name,
                                         ir::Expr predicate);

 private:
  std::vector<ir::Expr> buckets_;  // predicate-guarded launch calls
  ir::Var kernel_args_;
  ir::Var kernel_args_num_;
  ir::Var kernel_stream_;
};
} // namespace detail
} // namespace backends
......
......@@ -18,7 +18,9 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
......@@ -29,27 +31,83 @@
#include "paddle/cinn/runtime/flags.h"
#endif
DECLARE_string(cinn_source_code_save_path);
DECLARE_string(cinn_dump_group_lowered_func);
DECLARE_string(cinn_dump_group_source_code);
DECLARE_string(cinn_dump_group_ptx);
DECLARE_string(cinn_dump_group_instruction);
PD_DECLARE_string(cinn_source_code_save_path);
PD_DECLARE_string(cinn_dump_group_lowered_func);
PD_DECLARE_string(cinn_dump_group_source_code);
PD_DECLARE_string(cinn_dump_group_ptx);
PD_DECLARE_string(cinn_dump_group_instruction);
namespace cinn {
namespace backends {
using ir::Module;
using CompilationStatus = hlir::framework::CompilationStatus;
static constexpr int DebugLogMaxLen = 30000;
// Dumps one group's lowered function as text under the directory given by
// FLAGS_cinn_dump_group_lowered_func.  No-op when the flag is unset or the
// function handle is null.
void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
    const ir::LoweredFunc& lowered_func, const int gidx, const int device_id) {
  const bool enabled = !FLAGS_cinn_dump_group_lowered_func.empty() &&
                       lowered_func.get() != nullptr;
  if (!enabled) {
    return;
  }
  std::ostringstream ss;
  ss << lowered_func;
  Dump(FLAGS_cinn_dump_group_lowered_func,
       gidx,
       device_id,
       "lowered_function.txt",
       ss.str());
}
// Dumps one group's generated source code under the directory given by
// FLAGS_cinn_dump_group_source_code.  No-op when the flag is unset.
void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
    const std::string& source_code, const int gidx, const int device_id) {
  if (!FLAGS_cinn_dump_group_source_code.empty()) {
    Dump(FLAGS_cinn_dump_group_source_code,
         gidx,
         device_id,
         "source_code.cu",
         source_code);
  }
}
// Dumps one group's PTX under the directory given by
// FLAGS_cinn_dump_group_ptx.  No-op when the flag is unset.
void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
    const std::string& source_ptx, const int gidx, const int device_id) {
  if (FLAGS_cinn_dump_group_ptx.empty()) {
    return;
  }
  Dump(FLAGS_cinn_dump_group_ptx,
       gidx,
       device_id,
       "source_ptx.ptx",
       source_ptx);
}
// Dumps one group's runtime instruction as text under the directory given by
// FLAGS_cinn_dump_group_instruction.  No-op when the flag is unset or the
// instruction is null.
void CompilationInfoDumper::DumpInstructionByGroupIndex(
    const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
    const int gidx,
    const int device_id) {
  if (FLAGS_cinn_dump_group_instruction.empty() || !instr) {
    return;
  }
  const std::string text = instr->DumpInstruction();
  Dump(FLAGS_cinn_dump_group_instruction,
       gidx,
       device_id,
       "instruction.txt",
       text);
}
void CompilationInfoDumper::DumpLoweredFunc() {
if (FLAGS_cinn_dump_group_lowered_func.empty()) {
return;
}
for (int idx = 0; idx < info_.lowered_funcs.size(); ++idx) {
for (int idx = 0; idx < info_.Size(); ++idx) {
std::stringstream content;
content << info_.lowered_funcs[idx].front();
if (info_.Status(idx) > CompilationStatus::LOWERING_FAIL) {
content << info_.LoweredFuncs(idx).front();
} else {
content << "[No lowered func generated]\n\n" << info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_lowered_func,
idx,
device_id_,
"lowered_function.txt",
content.str());
}
......@@ -59,11 +117,18 @@ void CompilationInfoDumper::DumpSourceCode() {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
for (int idx = 0; idx < info_.source_codes.size(); ++idx) {
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourceCode(idx);
} else {
dump_str = "[No source code generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_source_code,
idx,
device_id_,
"source_code.cu",
info_.source_codes[idx]);
dump_str);
}
}
......@@ -71,11 +136,15 @@ void CompilationInfoDumper::DumpPtxCode() {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
for (int idx = 0; idx < info_.source_ptxs.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_ptx,
idx,
"source_ptx.ptx",
info_.source_ptxs[idx]);
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourcePtx(idx);
} else {
dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
}
Dump(
FLAGS_cinn_dump_group_ptx, idx, device_id_, "source_ptx.ptx", dump_str);
}
}
......@@ -83,20 +152,28 @@ void CompilationInfoDumper::DumpInstruction() {
if (FLAGS_cinn_dump_group_instruction.empty()) {
return;
}
for (int idx = 0; idx < info_.instructions.size(); ++idx) {
for (int idx = 0; idx < info_.RuntimeInstructions().size(); ++idx) {
std::string dump_str;
if (info_.RuntimeInstruction(idx).get() != nullptr) {
dump_str = info_.RuntimeInstruction(idx)->DumpInstruction();
} else {
dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_instruction,
idx,
device_id_,
"instruction.txt",
info_.instructions[idx]->DumpInstruction());
dump_str);
}
}
void CompilationInfoDumper::Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content) {
auto dump_path =
utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
auto dump_path = utils::StringFormat(
"%s/device_%d/fusion_group_%d", base_path.c_str(), device_id, idx);
if (!hlir::framework::MakeDirectory(
dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
LOG(WARNING) << "Failed to make directory: \"" << dump_path
......@@ -227,6 +304,8 @@ void Compiler::CompileCudaModule(const Module& module,
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
symbols.RegisterVar(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(fn_kernel));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment