Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
...@@ -43,26 +43,42 @@ namespace backends { ...@@ -43,26 +43,42 @@ namespace backends {
*/ */
class CompilationInfoDumper { class CompilationInfoDumper {
public: public:
explicit CompilationInfoDumper( explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info,
const hlir::framework::ParallelCompiler::CompilationResult& info) const int device_id)
: info_(info) { : info_(info), device_id_(device_id) {
DumpLoweredFunc(); DumpLoweredFunc();
DumpSourceCode(); DumpSourceCode();
DumpPtxCode(); DumpPtxCode();
DumpInstruction(); DumpInstruction();
} }
static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx,
const int device_id);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx,
const int device_id);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx,
const int device_id);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx,
const int device_id);
private: private:
void DumpLoweredFunc(); void DumpLoweredFunc();
void DumpSourceCode(); void DumpSourceCode();
void DumpPtxCode(); void DumpPtxCode();
void DumpInstruction(); void DumpInstruction();
void Dump(const std::string& base_path, static void Dump(const std::string& base_path,
const int idx, const int idx,
const std::string& file_name, const int device_id,
const std::string& content); const std::string& file_name,
const std::string& content);
const hlir::framework::ParallelCompiler::CompilationResult& info_;
const hlir::framework::CompilationResult& info_;
const int device_id_;
}; };
class SourceCodePrint { class SourceCodePrint {
...@@ -105,6 +121,8 @@ class Compiler final { ...@@ -105,6 +121,8 @@ class Compiler final {
*/ */
void* Lookup(absl::string_view fn_name); void* Lookup(absl::string_view fn_name);
std::vector<void*> GetFnPtr() const { return fn_ptr_; }
private: private:
void CompileCudaModule(const ir::Module& module, void CompileCudaModule(const ir::Module& module,
const std::string& code = ""); const std::string& code = "");
...@@ -120,6 +138,7 @@ class Compiler final { ...@@ -120,6 +138,7 @@ class Compiler final {
Target target_; Target target_;
std::unique_ptr<ExecutionEngine> engine_; std::unique_ptr<ExecutionEngine> engine_;
std::vector<void*> fn_ptr_;
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_; std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif #endif
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
DECLARE_bool(verbose_function_register); PD_DECLARE_bool(verbose_function_register);
namespace cinn { namespace cinn {
namespace backends { namespace backends {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
DECLARE_bool(verbose_function_register); PD_DECLARE_bool(verbose_function_register);
namespace cinn { namespace cinn {
namespace backends { namespace backends {
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include "paddle/cinn/backends/codegen_c_x86.h" #include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h" #include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_schedule_block.h" #include "paddle/cinn/optim/remove_schedule_block.h"
...@@ -690,6 +690,7 @@ void test_unroll(void* _args, int32_t num_args) ...@@ -690,6 +690,7 @@ void test_unroll(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
} }
#ifdef CINN_WITH_CUDA
TEST(IrSchedule, bind) { TEST(IrSchedule, bind) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
Expr M(32); Expr M(32);
...@@ -733,6 +734,7 @@ function test_bind (_A, _B) ...@@ -733,6 +734,7 @@ function test_bind (_A, _B)
} }
)ROC")); )ROC"));
} }
#endif
TEST(IrSchedule, simple_compute_at) { TEST(IrSchedule, simple_compute_at) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
...@@ -794,10 +796,8 @@ void test_simple_compute_at(void* _args, int32_t num_args) ...@@ -794,10 +796,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
}; };
}; };
}; };
...@@ -869,10 +869,8 @@ void test_compute_at0(void* _args, int32_t num_args) ...@@ -869,10 +869,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
}; };
}; };
}; };
...@@ -2314,6 +2312,270 @@ void test_rfactor(void* _args, int32_t num_args) ...@@ -2314,6 +2312,270 @@ void test_rfactor(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
} }
TEST(IrSchedule, factorize_reduction) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);
Target target = common::DefaultHostTarget();
Placeholder<float> A("A", {M, N, K});
Var j(4, "j0");
Var k(5, "k0");
auto B = Compute(
{M},
[&](Var i) {
return lang::ReduceSum(A(i, j, k), {j, k});
},
"B");
auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 3U);
auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 0);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[vj0, i0_0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[vj0, i0_0] = (B_rf[vj0, i0_0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[vj0, i0_0])
}
}
}
}
}
}
)ROC"));
}
TEST(IrSchedule, factorize_reduction1) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);
Target target = common::DefaultHostTarget();
Placeholder<float> A("A", {M, N, K});
Var j(4, "j0");
Var k(5, "k0");
auto B = Compute(
{M},
[&](Var i) {
return lang::ReduceSum(A(i, j, k), {j, k});
},
"B");
auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 3U);
auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 1);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}
TEST(IrSchedule, factorize_reduction2) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);
Target target = common::DefaultHostTarget();
Placeholder<float> A("A", {M, N * K});
Var j(4 * 5, "j0");
auto B = Compute(
{M}, [&](Var i) { return lang::ReduceSum(A(i, j), {j}); }, "B");
auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 2U);
auto splited_loops = ir_sch.Split(loops[1], {4, 5});
CHECK_EQ(splited_loops.size(), 2U);
auto new_rf_tensor = ir_sch.FactorizeReduction(splited_loops[0], 1);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (j0_0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, vj0_0 = axis.bind(j0, i, j0_0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, ((5 * vj0) + vj0_0)])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}
TEST(IrSchedule, compute_inline1) { TEST(IrSchedule, compute_inline1) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
Expr M(32); Expr M(32);
......
...@@ -43,8 +43,8 @@ ...@@ -43,8 +43,8 @@
#include "paddle/cinn/backends/llvm/llvm_util.h" #include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/type.h" #include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_verify.h" #include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/var_mod_simplify.h" #include "paddle/cinn/optim/var_mod_simplify.h"
#include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/runtime/cinn_runtime.h"
...@@ -747,6 +747,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) { ...@@ -747,6 +747,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) {
llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) { llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) {
CINN_NOT_IMPLEMENTED return nullptr; CINN_NOT_IMPLEMENTED return nullptr;
} }
llvm::Value *CodeGenLLVM::Visit(const ir::_Dim_ *) {
CINN_NOT_IMPLEMENTED return nullptr;
}
llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) { llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::debug_log_repr) { if (op->name == runtime::intrinsic::debug_log_repr) {
...@@ -790,7 +793,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) { ...@@ -790,7 +793,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) { llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
{ {
Expr body_to_verify(&Reference(op)); Expr body_to_verify(&Reference(op));
ir::IrVerify(body_to_verify); ir::ir_utils::IrVerify(body_to_verify);
} }
for (auto &fn : op->functions) { for (auto &fn : op->functions) {
......
...@@ -32,9 +32,9 @@ ...@@ -32,9 +32,9 @@
#include "paddle/cinn/backends/llvm/ir_builder_mixin.h" #include "paddle/cinn/backends/llvm/ir_builder_mixin.h"
#include "paddle/cinn/backends/llvm/llvm_util.h" #include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/ir/intrinsic_ops.h" #include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn { namespace cinn {
namespace backends { namespace backends {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include "paddle/cinn/common/target.h" #include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/optim/collect_undefined_vars.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/runtime/intrinsic.h"
namespace cinn::backends { namespace cinn::backends {
...@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) { ...@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) {
llvm::Function::PrivateLinkage, llvm::Function::PrivateLinkage,
"__parallel_lambda", "__parallel_lambda",
m_); m_);
std::vector<std::string> vars = optim::CollectUndefinedVars(&body); std::vector<std::string> vars = ir::ir_utils::CollectUndefinedVars(&body);
uint64_t nbytes; uint64_t nbytes;
auto* data = PackVars(vars, &nbytes); auto* data = PackVars(vars, &nbytes);
......
...@@ -61,7 +61,7 @@ ...@@ -61,7 +61,7 @@
#include "paddle/cinn/backends/llvm/llvm_optimizer.h" #include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include "paddle/cinn/backends/llvm/llvm_util.h" #include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/profiler.h" #include "paddle/cinn/utils/profiler.h"
......
...@@ -41,8 +41,8 @@ ...@@ -41,8 +41,8 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h" #include "paddle/cinn/lang/placeholder.h"
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include <iostream> #include <iostream>
#include "gflags/gflags_declare.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
#include "paddle/utils/flags.h"
DECLARE_bool(verbose_function_register); PD_DECLARE_bool(verbose_function_register);
namespace cinn { namespace cinn {
namespace backends { namespace backends {
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#include "paddle/cinn/backends/llvm/codegen_llvm.h" #include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h" #include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/runtime/intrinsic.h"
namespace cinn { namespace cinn {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "paddle/cinn/backends/modular.h" #include "paddle/cinn/backends/modular.h"
#include "paddle/cinn/ir/utils/ir_visitor.h" #include "paddle/cinn/ir/ir_visitor.h"
namespace cinn { namespace cinn {
namespace backends { namespace backends {
......
...@@ -30,8 +30,9 @@ ...@@ -30,8 +30,9 @@
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_nvcc_cmd_path); PD_DECLARE_string(cinn_nvcc_cmd_path);
DECLARE_bool(nvrtc_compile_to_cubin); PD_DECLARE_bool(nvrtc_compile_to_cubin);
PD_DECLARE_bool(cinn_nvrtc_cubin_with_fmad);
namespace cinn { namespace cinn {
namespace backends { namespace backends {
...@@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code, ...@@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code,
} }
if (compile_to_cubin_) { if (compile_to_cubin_) {
compile_options.push_back("-arch=sm_" + cc); compile_options.push_back("-arch=sm_" + cc);
std::string enable_fmad =
FLAGS_cinn_nvrtc_cubin_with_fmad ? "true" : "false";
compile_options.push_back("--fmad=" + enable_fmad);
} else { } else {
compile_options.push_back("-arch=compute_" + cc); compile_options.push_back("-arch=compute_" + cc);
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
namespace cinn { namespace cinn {
using ast_gen_ius::TensorGroup;
using backends::CodeGenC; using backends::CodeGenC;
using backends::CodeGenCX86; using backends::CodeGenCX86;
using backends::Outputs; using backends::Outputs;
...@@ -39,6 +40,7 @@ using lang::CallExtern; ...@@ -39,6 +40,7 @@ using lang::CallExtern;
using lang::CallLowered; using lang::CallLowered;
using lang::Compute; using lang::Compute;
using lang::Lower; using lang::Lower;
using lang::LowerToAst;
using lang::Placeholder; using lang::Placeholder;
using lang::ReduceAll; using lang::ReduceAll;
using lang::ReduceAny; using lang::ReduceAny;
......
...@@ -19,10 +19,13 @@ gather_srcs( ...@@ -19,10 +19,13 @@ gather_srcs(
arithmatic.cc arithmatic.cc
cas.cc cas.cc
union_find.cc union_find.cc
python_interpreter_guard.cc) python_interpreter_guard.cc
nvgpu_dev_info.cc)
message(STATUS "srcs: ${cinnapi_src}") message(STATUS "srcs: ${cinnapi_src}")
cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog) cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog) cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
#include <string> #include <string>
#include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
...@@ -126,7 +126,7 @@ GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) { ...@@ -126,7 +126,7 @@ GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {
GiNaC::ex ExprToGinacConverter::operator()(Expr expr) { GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
// TODO(Superjomn) Replace this with common::IsPureMath( // TODO(Superjomn) Replace this with common::IsPureMath(
auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) { auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [](const Expr* n) {
return n->As<Block>() || // return n->As<Block>() || //
n->As<PolyFor>() || // n->As<PolyFor>() || //
n->As<EQ>() || // n->As<EQ>() || //
...@@ -262,7 +262,7 @@ bool IsPureMath(Expr expr) { ...@@ -262,7 +262,7 @@ bool IsPureMath(Expr expr) {
IrNodeTy ::Minus, IrNodeTy ::Minus,
}); });
auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [&](const Expr* n) {
return !valid_node_tys.count(n->node_type()); return !valid_node_tys.count(n->node_type());
}); });
#ifdef CINN_DEBUG #ifdef CINN_DEBUG
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
......
...@@ -21,13 +21,12 @@ ...@@ -21,13 +21,12 @@
#include "paddle/cinn/common/arithmatic.h" #include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
...@@ -1585,7 +1584,7 @@ bool CASasSymbol(Expr expr) { ...@@ -1585,7 +1584,7 @@ bool CASasSymbol(Expr expr) {
Expr ConvertCinnToCAS(Expr expr) { Expr ConvertCinnToCAS(Expr expr) {
VLOG(7) << "Begin ConvertCinnToCAS " << expr; VLOG(7) << "Begin ConvertCinnToCAS " << expr;
Expr copied = optim::IRCopy(expr); Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> { struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); } void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
...@@ -1711,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) { ...@@ -1711,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) {
* simplify the condition ensures correctness, though not sufficient. * simplify the condition ensures correctness, though not sufficient.
*/ */
Expr ReplaceMinToConstant(Expr expr) { Expr ReplaceMinToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr); Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> { struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); } void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
...@@ -1728,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) { ...@@ -1728,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) {
auto min_b = op->b(); auto min_b = op->b();
if (min_a.is_constant() && !min_b.is_constant()) { if (min_a.is_constant() && !min_b.is_constant()) {
CHECK(min_a->type().is_integer()); CHECK(min_a->type().is_integer());
*expr = optim::IRCopy(min_a); *expr = ir::ir_utils::IRCopy(min_a);
} else if (min_b.is_constant() && !min_a.is_constant()) { } else if (min_b.is_constant() && !min_a.is_constant()) {
CHECK(min_b->type().is_integer()); CHECK(min_b->type().is_integer());
*expr = optim::IRCopy(min_b); *expr = ir::ir_utils::IRCopy(min_b);
} }
} }
}; };
...@@ -1744,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) { ...@@ -1744,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) {
* constant value and 1 inconstant value, return the constant max value. * constant value and 1 inconstant value, return the constant max value.
*/ */
Expr ReplaceMaxToConstant(Expr expr) { Expr ReplaceMaxToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr); Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> { struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); } void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
...@@ -1761,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) { ...@@ -1761,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto max_b = op->b(); auto max_b = op->b();
if (max_a.is_constant() && !max_b.is_constant()) { if (max_a.is_constant() && !max_b.is_constant()) {
CHECK(max_a->type().is_integer()); CHECK(max_a->type().is_integer());
*expr = optim::IRCopy(max_a); *expr = ir::ir_utils::IRCopy(max_a);
} else if (max_b.is_constant() && !max_a.is_constant()) { } else if (max_b.is_constant() && !max_a.is_constant()) {
CHECK(max_b->type().is_integer()); CHECK(max_b->type().is_integer());
*expr = optim::IRCopy(max_b); *expr = ir::ir_utils::IRCopy(max_b);
} }
} }
}; };
...@@ -1774,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) { ...@@ -1774,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) {
Expr ConvertCasToCinn(Expr expr) { Expr ConvertCasToCinn(Expr expr) {
VLOG(7) << "Begin ConvertCasToCinn : " << expr; VLOG(7) << "Begin ConvertCasToCinn : " << expr;
Expr copied = optim::IRCopy(expr); Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : ir::IRMutator<Expr*> { struct Mutator : ir::IRMutator<Expr*> {
void operator()(Expr* expr) { Visit(expr); } void operator()(Expr* expr) { Visit(expr); }
...@@ -1869,7 +1868,7 @@ bool IsExprCasCompatible(Expr expr) { ...@@ -1869,7 +1868,7 @@ bool IsExprCasCompatible(Expr expr) {
return expr->As<Add>() || expr->As<Sub>() || expr->As<Mul>() || return expr->As<Add>() || expr->As<Sub>() || expr->As<Mul>() ||
expr->As<Div>(); expr->As<Div>();
}; };
return ir::CollectIRNodes(expr, teller).empty(); return ir::ir_utils::CollectIRNodes(expr, teller).empty();
} }
// Partially divide a by b. e.g. (2x+y)/2 => x + y/2 // Partially divide a by b. e.g. (2x+y)/2 => x + y/2
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <vector> #include <vector>
#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/ir_simplify.h"
namespace cinn { namespace cinn {
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "paddle/cinn/cinn.h" #include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/common.h" #include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment