Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -43,26 +43,42 @@ namespace backends {
*/
class CompilationInfoDumper {
public:
// Binds the dumper to a compilation result and immediately runs the four
// dump steps listed in the constructor body, in the order written there
// (lowered function, source code, PTX code, instruction).
explicit CompilationInfoDumper(
const hlir::framework::ParallelCompiler::CompilationResult& info)
: info_(info) {
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info,
const int device_id)
: info_(info), device_id_(device_id) {
DumpLoweredFunc();
DumpSourceCode();
DumpPtxCode();
DumpInstruction();
}
// Static per-group entry points: each takes the single item to dump, a group
// index (gidx) and a device id, and needs no CompilationInfoDumper instance.
// Their definitions are not part of this listing.
static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx,
const int device_id);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx,
const int device_id);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx,
const int device_id);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx,
const int device_id);
private:
// Member dump steps; all four are invoked from the constructor.
void DumpLoweredFunc();
void DumpSourceCode();
void DumpPtxCode();
void DumpInstruction();
// Low-level helper receiving a base path, an index, a file name and the
// content to dump.
// NOTE(review): presumably writes `content` to a file derived from the other
// arguments -- the definition is not shown here, confirm before relying on it.
void Dump(const std::string& base_path,
const int idx,
const std::string& file_name,
const std::string& content);
const hlir::framework::ParallelCompiler::CompilationResult& info_;
static void Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content);
// Held by reference, not copied: the CompilationResult handed to the
// constructor must stay alive for as long as this object uses it.
const hlir::framework::CompilationResult& info_;
// Device id supplied at construction.
const int device_id_;
};
class SourceCodePrint {
......@@ -105,6 +121,8 @@ class Compiler final {
*/
void* Lookup(absl::string_view fn_name);
std::vector<void*> GetFnPtr() const { return fn_ptr_; }
private:
void CompileCudaModule(const ir::Module& module,
const std::string& code = "");
......@@ -120,6 +138,7 @@ class Compiler final {
Target target_;
std::unique_ptr<ExecutionEngine> engine_;
std::vector<void*> fn_ptr_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
......
......@@ -27,7 +27,7 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool(verbose_function_register);
PD_DECLARE_bool(verbose_function_register);
namespace cinn {
namespace backends {
......
......@@ -21,7 +21,7 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(verbose_function_register);
PD_DECLARE_bool(verbose_function_register);
namespace cinn {
namespace backends {
......
......@@ -24,8 +24,8 @@
#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
......@@ -690,6 +690,7 @@ void test_unroll(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}
#ifdef CINN_WITH_CUDA
TEST(IrSchedule, bind) {
Context::Global().ResetNameId();
Expr M(32);
......@@ -733,6 +734,7 @@ function test_bind (_A, _B)
}
)ROC"));
}
#endif
TEST(IrSchedule, simple_compute_at) {
Context::Global().ResetNameId();
......@@ -794,10 +796,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
......@@ -869,10 +869,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
......@@ -2314,6 +2312,270 @@ void test_rfactor(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}
// Lowers B[i] = sum over (j0, k0) of A[i, j0, k0] for a 3x4x5 input, applies
// FactorizeReduction on the j0 loop with rf axis 0, and compares the printed
// function with the expected IR. With rf axis 0 the factorized index vj0 is
// the leading index of the intermediate tensor B_rf in the expected text.
TEST(IrSchedule, factorize_reduction) {
  // Restart name-id generation before building the computation.
  Context::Global().ResetNameId();

  Expr M(3);
  Expr N(4);
  Expr K(5);

  Target target = common::DefaultHostTarget();

  Placeholder<float> A("A", {M, N, K});
  Var j(4, "j0");
  Var k(5, "k0");
  // Reduction body: sums A(i, j, k) over both reduce axes j and k.
  auto reduce_over_jk = [&](Var i) {
    return lang::ReduceSum(A(i, j, k), {j, k});
  };
  auto B = Compute({M}, reduce_over_jk, "B");

  auto stages = CreateStages({A, B});
  auto func = cinn::lang::LowerVec(
      "test_factorize_reduction", stages, {A, B}, {}, {}, nullptr, target, true);
  CHECK(!func.empty());
  auto& lowered_fn = func[0];

  // Wrap the lowered body in a single-expression module for the scheduler.
  std::vector<Expr> body_exprs{lowered_fn->body};
  ir::ModuleExpr module_expr(body_exprs);
  ir::IRSchedule schedule(module_expr);

  auto loops = schedule.GetLoops("B");
  CHECK_EQ(loops.size(), 3U);

  // Factorize over loops[1] (the j0 loop in the expected output), rf axis 0.
  auto rf_tensor = schedule.FactorizeReduction(loops[1], 0);
  auto* new_rf_tensor_ref = rf_tensor.As<ir::_Tensor_>();
  CHECK(new_rf_tensor_ref);
  CHECK(new_rf_tensor_ref->buffer.defined());

  // Register the new rf tensor's buffer as a temp buffer of the function
  // before printing it.
  lowered_fn->temp_bufs.push_back(new_rf_tensor_ref->buffer);
  lowered_fn->PrepareBufferCastExprs();

  const std::string origin = utils::GetStreamCnt(lowered_fn);
  LOG(INFO) << origin;
  EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[vj0, i0_0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[vj0, i0_0] = (B_rf[vj0, i0_0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[vj0, i0_0])
}
}
}
}
}
}
)ROC"));
}
// Same computation as the rf-axis-0 case (B[i] = sum over (j0, k0) of
// A[i, j0, k0] on a 3x4x5 input), but FactorizeReduction is called with rf
// axis 1, so the factorized index vj0 is the second index of B_rf in the
// expected text.
TEST(IrSchedule, factorize_reduction1) {
  // Restart name-id generation before building the computation.
  Context::Global().ResetNameId();

  Expr M(3);
  Expr N(4);
  Expr K(5);

  Target target = common::DefaultHostTarget();

  Placeholder<float> A("A", {M, N, K});
  Var j(4, "j0");
  Var k(5, "k0");
  // Reduction body: sums A(i, j, k) over both reduce axes j and k.
  auto reduce_over_jk = [&](Var i) {
    return lang::ReduceSum(A(i, j, k), {j, k});
  };
  auto B = Compute({M}, reduce_over_jk, "B");

  auto stages = CreateStages({A, B});
  auto func = cinn::lang::LowerVec(
      "test_factorize_reduction", stages, {A, B}, {}, {}, nullptr, target, true);
  CHECK(!func.empty());
  auto& lowered_fn = func[0];

  // Wrap the lowered body in a single-expression module for the scheduler.
  std::vector<Expr> body_exprs{lowered_fn->body};
  ir::ModuleExpr module_expr(body_exprs);
  ir::IRSchedule schedule(module_expr);

  auto loops = schedule.GetLoops("B");
  CHECK_EQ(loops.size(), 3U);

  // Factorize over loops[1] (the j0 loop in the expected output), rf axis 1.
  auto rf_tensor = schedule.FactorizeReduction(loops[1], 1);
  auto* new_rf_tensor_ref = rf_tensor.As<ir::_Tensor_>();
  CHECK(new_rf_tensor_ref);
  CHECK(new_rf_tensor_ref->buffer.defined());

  // Register the new rf tensor's buffer as a temp buffer of the function
  // before printing it.
  lowered_fn->temp_bufs.push_back(new_rf_tensor_ref->buffer);
  lowered_fn->PrepareBufferCastExprs();

  const std::string origin = utils::GetStreamCnt(lowered_fn);
  LOG(INFO) << origin;
  EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}
// Reduces a 3 x (4*5) input along a single axis j0 of extent 20, splits that
// reduce loop by {4, 5}, then applies FactorizeReduction on the outer split
// loop with rf axis 1. The expected text indexes A with ((5 * vj0) + vj0_0).
TEST(IrSchedule, factorize_reduction2) {
  // Restart name-id generation before building the computation.
  Context::Global().ResetNameId();

  Expr M(3);
  Expr N(4);
  Expr K(5);

  Target target = common::DefaultHostTarget();

  Placeholder<float> A("A", {M, N * K});
  Var j(4 * 5, "j0");
  // Reduction body: sums A(i, j) over the single reduce axis j.
  auto reduce_over_j = [&](Var i) { return lang::ReduceSum(A(i, j), {j}); };
  auto B = Compute({M}, reduce_over_j, "B");

  auto stages = CreateStages({A, B});
  auto func = cinn::lang::LowerVec(
      "test_factorize_reduction", stages, {A, B}, {}, {}, nullptr, target, true);
  CHECK(!func.empty());
  auto& lowered_fn = func[0];

  // Wrap the lowered body in a single-expression module for the scheduler.
  std::vector<Expr> body_exprs{lowered_fn->body};
  ir::ModuleExpr module_expr(body_exprs);
  ir::IRSchedule schedule(module_expr);

  auto loops = schedule.GetLoops("B");
  CHECK_EQ(loops.size(), 2U);

  // Split the reduce loop into factors {4, 5}; two loops must come back.
  auto splited_loops = schedule.Split(loops[1], {4, 5});
  CHECK_EQ(splited_loops.size(), 2U);

  // Factorize over the first split loop, rf axis 1.
  auto rf_tensor = schedule.FactorizeReduction(splited_loops[0], 1);
  auto* new_rf_tensor_ref = rf_tensor.As<ir::_Tensor_>();
  CHECK(new_rf_tensor_ref);
  CHECK(new_rf_tensor_ref->buffer.defined());

  // Register the new rf tensor's buffer as a temp buffer of the function
  // before printing it.
  lowered_fn->temp_bufs.push_back(new_rf_tensor_ref->buffer);
  lowered_fn->PrepareBufferCastExprs();

  const std::string origin = utils::GetStreamCnt(lowered_fn);
  LOG(INFO) << origin;
  EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (j0_0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, vj0_0 = axis.bind(j0, i, j0_0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, ((5 * vj0) + vj0_0)])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}
TEST(IrSchedule, compute_inline1) {
Context::Global().ResetNameId();
Expr M(32);
......
......@@ -43,8 +43,8 @@
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/var_mod_simplify.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
......@@ -747,6 +747,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) {
// ir::ScheduleBlockRealize nodes have no LLVM lowering in this code generator:
// the handler ignores its argument and only invokes CINN_NOT_IMPLEMENTED.
// NOTE(review): CINN_NOT_IMPLEMENTED is defined elsewhere -- presumably it
// reports a fatal error, in which case the `return nullptr` only satisfies
// the signature; confirm against the macro definition.
llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) {
CINN_NOT_IMPLEMENTED return nullptr;
}
// ir::_Dim_ nodes have no LLVM lowering in this code generator: the handler
// ignores its argument and only invokes CINN_NOT_IMPLEMENTED.
// NOTE(review): CINN_NOT_IMPLEMENTED is defined elsewhere -- presumably it
// reports a fatal error, in which case the `return nullptr` only satisfies
// the signature; confirm against the macro definition.
llvm::Value *CodeGenLLVM::Visit(const ir::_Dim_ *) {
CINN_NOT_IMPLEMENTED return nullptr;
}
llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::debug_log_repr) {
......@@ -790,7 +793,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
{
Expr body_to_verify(&Reference(op));
ir::IrVerify(body_to_verify);
ir::ir_utils::IrVerify(body_to_verify);
}
for (auto &fn : op->functions) {
......
......@@ -32,9 +32,9 @@
#include "paddle/cinn/backends/llvm/ir_builder_mixin.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace backends {
......
......@@ -28,7 +28,7 @@
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/optim/collect_undefined_vars.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn::backends {
......@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) {
llvm::Function::PrivateLinkage,
"__parallel_lambda",
m_);
std::vector<std::string> vars = optim::CollectUndefinedVars(&body);
std::vector<std::string> vars = ir::ir_utils::CollectUndefinedVars(&body);
uint64_t nbytes;
auto* data = PackVars(vars, &nbytes);
......
......@@ -61,7 +61,7 @@
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/profiler.h"
......
......@@ -41,8 +41,8 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
......
......@@ -19,10 +19,10 @@
#include <iostream>
#include "gflags/gflags_declare.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/utils/flags.h"
DECLARE_bool(verbose_function_register);
PD_DECLARE_bool(verbose_function_register);
namespace cinn {
namespace backends {
......
......@@ -37,7 +37,7 @@
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn {
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/backends/modular.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace cinn {
namespace backends {
......
......@@ -30,8 +30,9 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_nvcc_cmd_path);
DECLARE_bool(nvrtc_compile_to_cubin);
PD_DECLARE_string(cinn_nvcc_cmd_path);
PD_DECLARE_bool(nvrtc_compile_to_cubin);
PD_DECLARE_bool(cinn_nvrtc_cubin_with_fmad);
namespace cinn {
namespace backends {
......@@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code,
}
if (compile_to_cubin_) {
compile_options.push_back("-arch=sm_" + cc);
std::string enable_fmad =
FLAGS_cinn_nvrtc_cubin_with_fmad ? "true" : "false";
compile_options.push_back("--fmad=" + enable_fmad);
} else {
compile_options.push_back("-arch=compute_" + cc);
}
......
......@@ -29,6 +29,7 @@
namespace cinn {
using ast_gen_ius::TensorGroup;
using backends::CodeGenC;
using backends::CodeGenCX86;
using backends::Outputs;
......@@ -39,6 +40,7 @@ using lang::CallExtern;
using lang::CallLowered;
using lang::Compute;
using lang::Lower;
using lang::LowerToAst;
using lang::Placeholder;
using lang::ReduceAll;
using lang::ReduceAny;
......
......@@ -19,10 +19,13 @@ gather_srcs(
arithmatic.cc
cas.cc
union_find.cc
python_interpreter_guard.cc)
python_interpreter_guard.cc
nvgpu_dev_info.cc)
message(STATUS "srcs: ${cinnapi_src}")
cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
......
......@@ -21,9 +21,9 @@
#include <string>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -126,7 +126,7 @@ GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {
GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
// TODO(Superjomn) Replace this with common::IsPureMath(
auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) {
auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [](const Expr* n) {
return n->As<Block>() || //
n->As<PolyFor>() || //
n->As<EQ>() || //
......@@ -262,7 +262,7 @@ bool IsPureMath(Expr expr) {
IrNodeTy ::Minus,
});
auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) {
auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [&](const Expr* n) {
return !valid_node_tys.count(n->node_type());
});
#ifdef CINN_DEBUG
......
......@@ -20,8 +20,8 @@
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
......@@ -21,13 +21,12 @@
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -1585,7 +1584,7 @@ bool CASasSymbol(Expr expr) {
Expr ConvertCinnToCAS(Expr expr) {
VLOG(7) << "Begin ConvertCinnToCAS " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
......@@ -1711,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) {
* simplify the condition ensures correctness, though not sufficient.
*/
Expr ReplaceMinToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
......@@ -1728,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) {
auto min_b = op->b();
if (min_a.is_constant() && !min_b.is_constant()) {
CHECK(min_a->type().is_integer());
*expr = optim::IRCopy(min_a);
*expr = ir::ir_utils::IRCopy(min_a);
} else if (min_b.is_constant() && !min_a.is_constant()) {
CHECK(min_b->type().is_integer());
*expr = optim::IRCopy(min_b);
*expr = ir::ir_utils::IRCopy(min_b);
}
}
};
......@@ -1744,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) {
* constant value and 1 inconstant value, return the constant max value.
*/
Expr ReplaceMaxToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
......@@ -1761,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto max_b = op->b();
if (max_a.is_constant() && !max_b.is_constant()) {
CHECK(max_a->type().is_integer());
*expr = optim::IRCopy(max_a);
*expr = ir::ir_utils::IRCopy(max_a);
} else if (max_b.is_constant() && !max_a.is_constant()) {
CHECK(max_b->type().is_integer());
*expr = optim::IRCopy(max_b);
*expr = ir::ir_utils::IRCopy(max_b);
}
}
};
......@@ -1774,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) {
Expr ConvertCasToCinn(Expr expr) {
VLOG(7) << "Begin ConvertCasToCinn : " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : ir::IRMutator<Expr*> {
void operator()(Expr* expr) { Visit(expr); }
......@@ -1869,7 +1868,7 @@ bool IsExprCasCompatible(Expr expr) {
return expr->As<Add>() || expr->As<Sub>() || expr->As<Mul>() ||
expr->As<Div>();
};
return ir::CollectIRNodes(expr, teller).empty();
return ir::ir_utils::CollectIRNodes(expr, teller).empty();
}
// Partially divide a by b. e.g. (2x+y)/2 => x + y/2
......
......@@ -20,7 +20,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn {
......
......@@ -19,8 +19,8 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment