Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
342 additions
and
50 deletions
+342
-50
paddle/cinn/backends/compiler.h
paddle/cinn/backends/compiler.h
+28
-9
paddle/cinn/backends/extern_func_emitter.cc
paddle/cinn/backends/extern_func_emitter.cc
+1
-1
paddle/cinn/backends/function_prototype.cc
paddle/cinn/backends/function_prototype.cc
+1
-1
paddle/cinn/backends/ir_schedule_test.cc
paddle/cinn/backends/ir_schedule_test.cc
+267
-5
paddle/cinn/backends/llvm/codegen_llvm.cc
paddle/cinn/backends/llvm/codegen_llvm.cc
+5
-2
paddle/cinn/backends/llvm/codegen_llvm.h
paddle/cinn/backends/llvm/codegen_llvm.h
+1
-1
paddle/cinn/backends/llvm/codegen_x86.cc
paddle/cinn/backends/llvm/codegen_x86.cc
+2
-2
paddle/cinn/backends/llvm/execution_engine.cc
paddle/cinn/backends/llvm/execution_engine.cc
+1
-1
paddle/cinn/backends/llvm/execution_engine_test.cc
paddle/cinn/backends/llvm/execution_engine_test.cc
+1
-1
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
+2
-2
paddle/cinn/backends/llvm/simple_jit.cc
paddle/cinn/backends/llvm/simple_jit.cc
+1
-1
paddle/cinn/backends/modular.cc
paddle/cinn/backends/modular.cc
+1
-1
paddle/cinn/backends/nvrtc/nvrtc_util.cc
paddle/cinn/backends/nvrtc/nvrtc_util.cc
+6
-2
paddle/cinn/cinn.h
paddle/cinn/cinn.h
+2
-0
paddle/cinn/common/CMakeLists.txt
paddle/cinn/common/CMakeLists.txt
+4
-1
paddle/cinn/common/arithmatic.cc
paddle/cinn/common/arithmatic.cc
+4
-4
paddle/cinn/common/arithmatic_test.cc
paddle/cinn/common/arithmatic_test.cc
+1
-1
paddle/cinn/common/cas.cc
paddle/cinn/common/cas.cc
+12
-13
paddle/cinn/common/cas.h
paddle/cinn/common/cas.h
+1
-1
paddle/cinn/common/cas_test.cc
paddle/cinn/common/cas_test.cc
+1
-1
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/backends/compiler.h
View file @
01a10755
...
...
@@ -43,26 +43,42 @@ namespace backends {
*/
class
CompilationInfoDumper
{
public:
explicit
CompilationInfoDumper
(
const
hlir
::
framework
::
ParallelCompiler
::
CompilationResult
&
info
)
:
info_
(
info
)
{
explicit
CompilationInfoDumper
(
const
hlir
::
framework
::
CompilationResult
&
info
,
const
int
device_id
)
:
info_
(
info
)
,
device_id_
(
device_id
)
{
DumpLoweredFunc
();
DumpSourceCode
();
DumpPtxCode
();
DumpInstruction
();
}
static
void
DumpLoweredFuncByGroupIndex
(
const
ir
::
LoweredFunc
&
lowered_func
,
const
int
gidx
,
const
int
device_id
);
static
void
DumpSourceCodeByGroupIndex
(
const
std
::
string
&
source_code
,
const
int
gidx
,
const
int
device_id
);
static
void
DumpPtxCodeByGroupIndex
(
const
std
::
string
&
source_ptx
,
const
int
gidx
,
const
int
device_id
);
static
void
DumpInstructionByGroupIndex
(
const
std
::
unique_ptr
<
cinn
::
hlir
::
framework
::
Instruction
>&
instr
,
const
int
gidx
,
const
int
device_id
);
private:
void
DumpLoweredFunc
();
void
DumpSourceCode
();
void
DumpPtxCode
();
void
DumpInstruction
();
void
Dump
(
const
std
::
string
&
base_path
,
const
int
idx
,
const
std
::
string
&
file_name
,
const
std
::
string
&
content
);
const
hlir
::
framework
::
ParallelCompiler
::
CompilationResult
&
info_
;
static
void
Dump
(
const
std
::
string
&
base_path
,
const
int
idx
,
const
int
device_id
,
const
std
::
string
&
file_name
,
const
std
::
string
&
content
);
const
hlir
::
framework
::
CompilationResult
&
info_
;
const
int
device_id_
;
};
class
SourceCodePrint
{
...
...
@@ -105,6 +121,8 @@ class Compiler final {
*/
void
*
Lookup
(
absl
::
string_view
fn_name
);
std
::
vector
<
void
*>
GetFnPtr
()
const
{
return
fn_ptr_
;
}
private:
void
CompileCudaModule
(
const
ir
::
Module
&
module
,
const
std
::
string
&
code
=
""
);
...
...
@@ -120,6 +138,7 @@ class Compiler final {
Target
target_
;
std
::
unique_ptr
<
ExecutionEngine
>
engine_
;
std
::
vector
<
void
*>
fn_ptr_
;
#ifdef CINN_WITH_CUDA
std
::
unique_ptr
<
runtime
::
cuda
::
CUDAModule
>
cuda_module_
;
#endif
...
...
paddle/cinn/backends/extern_func_emitter.cc
View file @
01a10755
...
...
@@ -27,7 +27,7 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool
(
verbose_function_register
);
PD_
DECLARE_bool
(
verbose_function_register
);
namespace
cinn
{
namespace
backends
{
...
...
paddle/cinn/backends/function_prototype.cc
View file @
01a10755
...
...
@@ -21,7 +21,7 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool
(
verbose_function_register
);
PD_
DECLARE_bool
(
verbose_function_register
);
namespace
cinn
{
namespace
backends
{
...
...
paddle/cinn/backends/ir_schedule_test.cc
View file @
01a10755
...
...
@@ -24,8 +24,8 @@
#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
...
...
@@ -690,6 +690,7 @@ void test_unroll(void* _args, int32_t num_args)
ASSERT_EQ
(
utils
::
Trim
(
target_code
),
utils
::
Trim
(
source_code
));
}
#ifdef CINN_WITH_CUDA
TEST
(
IrSchedule
,
bind
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
32
);
...
...
@@ -733,6 +734,7 @@ function test_bind (_A, _B)
}
)ROC"
));
}
#endif
TEST
(
IrSchedule
,
simple_compute_at
)
{
Context
::
Global
().
ResetNameId
();
...
...
@@ -794,10 +796,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
...
...
@@ -869,10 +869,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
...
...
@@ -2314,6 +2312,270 @@ void test_rfactor(void* _args, int32_t num_args)
ASSERT_EQ
(
utils
::
Trim
(
target_code
),
utils
::
Trim
(
source_code
));
}
TEST
(
IrSchedule
,
factorize_reduction
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
3
);
Expr
N
(
4
);
Expr
K
(
5
);
Target
target
=
common
::
DefaultHostTarget
();
Placeholder
<
float
>
A
(
"A"
,
{
M
,
N
,
K
});
Var
j
(
4
,
"j0"
);
Var
k
(
5
,
"k0"
);
auto
B
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
lang
::
ReduceSum
(
A
(
i
,
j
,
k
),
{
j
,
k
});
},
"B"
);
auto
stages
=
CreateStages
({
A
,
B
});
auto
func
=
cinn
::
lang
::
LowerVec
(
"test_factorize_reduction"
,
stages
,
{
A
,
B
},
{},
{},
nullptr
,
target
,
true
);
CHECK
(
!
func
.
empty
());
auto
ast_expr
=
func
[
0
]
->
body
;
std
::
vector
<
Expr
>
vec_ast
{
ast_expr
};
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
auto
loops
=
ir_sch
.
GetLoops
(
"B"
);
CHECK_EQ
(
loops
.
size
(),
3U
);
auto
new_rf_tensor
=
ir_sch
.
FactorizeReduction
(
loops
[
1
],
0
);
auto
*
new_rf_tensor_ref
=
new_rf_tensor
.
As
<
ir
::
_Tensor_
>
();
CHECK
(
new_rf_tensor_ref
);
CHECK
(
new_rf_tensor_ref
->
buffer
.
defined
());
func
[
0
]
->
temp_bufs
.
push_back
(
new_rf_tensor_ref
->
buffer
);
func
[
0
]
->
PrepareBufferCastExprs
();
std
::
string
origin
=
utils
::
GetStreamCnt
(
func
[
0
]);
LOG
(
INFO
)
<<
origin
;
EXPECT_EQ
(
origin
,
utils
::
Trim
(
R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[vj0, i0_0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[vj0, i0_0] = (B_rf[vj0, i0_0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[vj0, i0_0])
}
}
}
}
}
}
)ROC"
));
}
TEST
(
IrSchedule
,
factorize_reduction1
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
3
);
Expr
N
(
4
);
Expr
K
(
5
);
Target
target
=
common
::
DefaultHostTarget
();
Placeholder
<
float
>
A
(
"A"
,
{
M
,
N
,
K
});
Var
j
(
4
,
"j0"
);
Var
k
(
5
,
"k0"
);
auto
B
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
lang
::
ReduceSum
(
A
(
i
,
j
,
k
),
{
j
,
k
});
},
"B"
);
auto
stages
=
CreateStages
({
A
,
B
});
auto
func
=
cinn
::
lang
::
LowerVec
(
"test_factorize_reduction"
,
stages
,
{
A
,
B
},
{},
{},
nullptr
,
target
,
true
);
CHECK
(
!
func
.
empty
());
auto
ast_expr
=
func
[
0
]
->
body
;
std
::
vector
<
Expr
>
vec_ast
{
ast_expr
};
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
auto
loops
=
ir_sch
.
GetLoops
(
"B"
);
CHECK_EQ
(
loops
.
size
(),
3U
);
auto
new_rf_tensor
=
ir_sch
.
FactorizeReduction
(
loops
[
1
],
1
);
auto
*
new_rf_tensor_ref
=
new_rf_tensor
.
As
<
ir
::
_Tensor_
>
();
CHECK
(
new_rf_tensor_ref
);
CHECK
(
new_rf_tensor_ref
->
buffer
.
defined
());
func
[
0
]
->
temp_bufs
.
push_back
(
new_rf_tensor_ref
->
buffer
);
func
[
0
]
->
PrepareBufferCastExprs
();
std
::
string
origin
=
utils
::
GetStreamCnt
(
func
[
0
]);
LOG
(
INFO
)
<<
origin
;
EXPECT_EQ
(
origin
,
utils
::
Trim
(
R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"
));
}
TEST
(
IrSchedule
,
factorize_reduction2
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
3
);
Expr
N
(
4
);
Expr
K
(
5
);
Target
target
=
common
::
DefaultHostTarget
();
Placeholder
<
float
>
A
(
"A"
,
{
M
,
N
*
K
});
Var
j
(
4
*
5
,
"j0"
);
auto
B
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
lang
::
ReduceSum
(
A
(
i
,
j
),
{
j
});
},
"B"
);
auto
stages
=
CreateStages
({
A
,
B
});
auto
func
=
cinn
::
lang
::
LowerVec
(
"test_factorize_reduction"
,
stages
,
{
A
,
B
},
{},
{},
nullptr
,
target
,
true
);
CHECK
(
!
func
.
empty
());
auto
ast_expr
=
func
[
0
]
->
body
;
std
::
vector
<
Expr
>
vec_ast
{
ast_expr
};
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
auto
loops
=
ir_sch
.
GetLoops
(
"B"
);
CHECK_EQ
(
loops
.
size
(),
2U
);
auto
splited_loops
=
ir_sch
.
Split
(
loops
[
1
],
{
4
,
5
});
CHECK_EQ
(
splited_loops
.
size
(),
2U
);
auto
new_rf_tensor
=
ir_sch
.
FactorizeReduction
(
splited_loops
[
0
],
1
);
auto
*
new_rf_tensor_ref
=
new_rf_tensor
.
As
<
ir
::
_Tensor_
>
();
CHECK
(
new_rf_tensor_ref
);
CHECK
(
new_rf_tensor_ref
->
buffer
.
defined
());
func
[
0
]
->
temp_bufs
.
push_back
(
new_rf_tensor_ref
->
buffer
);
func
[
0
]
->
PrepareBufferCastExprs
();
std
::
string
origin
=
utils
::
GetStreamCnt
(
func
[
0
]);
LOG
(
INFO
)
<<
origin
;
EXPECT_EQ
(
origin
,
utils
::
Trim
(
R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (j0_0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, vj0_0 = axis.bind(j0, i, j0_0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, ((5 * vj0) + vj0_0)])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"
));
}
TEST
(
IrSchedule
,
compute_inline1
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
32
);
...
...
paddle/cinn/backends/llvm/codegen_llvm.cc
View file @
01a10755
...
...
@@ -43,8 +43,8 @@
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/var_mod_simplify.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
...
...
@@ -747,6 +747,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) {
llvm
::
Value
*
CodeGenLLVM
::
Visit
(
const
ir
::
ScheduleBlockRealize
*
)
{
CINN_NOT_IMPLEMENTED
return
nullptr
;
}
llvm
::
Value
*
CodeGenLLVM
::
Visit
(
const
ir
::
_Dim_
*
)
{
CINN_NOT_IMPLEMENTED
return
nullptr
;
}
llvm
::
Value
*
CodeGenLLVM
::
Visit
(
const
ir
::
Call
*
op
)
{
if
(
op
->
name
==
runtime
::
intrinsic
::
debug_log_repr
)
{
...
...
@@ -790,7 +793,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
llvm
::
Value
*
CodeGenLLVM
::
Visit
(
const
ir
::
_Module_
*
op
)
{
{
Expr
body_to_verify
(
&
Reference
(
op
));
ir
::
IrVerify
(
body_to_verify
);
ir
::
ir_utils
::
IrVerify
(
body_to_verify
);
}
for
(
auto
&
fn
:
op
->
functions
)
{
...
...
paddle/cinn/backends/llvm/codegen_llvm.h
View file @
01a10755
...
...
@@ -32,9 +32,9 @@
#include "paddle/cinn/backends/llvm/ir_builder_mixin.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
backends
{
...
...
paddle/cinn/backends/llvm/codegen_x86.cc
View file @
01a10755
...
...
@@ -28,7 +28,7 @@
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/
optim/collect_undefined_vars
.h"
#include "paddle/cinn/
ir/utils/ir_nodes_collector
.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
::
backends
{
...
...
@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) {
llvm
::
Function
::
PrivateLinkage
,
"__parallel_lambda"
,
m_
);
std
::
vector
<
std
::
string
>
vars
=
optim
::
CollectUndefinedVars
(
&
body
);
std
::
vector
<
std
::
string
>
vars
=
ir
::
ir_utils
::
CollectUndefinedVars
(
&
body
);
uint64_t
nbytes
;
auto
*
data
=
PackVars
(
vars
,
&
nbytes
);
...
...
paddle/cinn/backends/llvm/execution_engine.cc
View file @
01a10755
...
...
@@ -61,7 +61,7 @@
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/profiler.h"
...
...
paddle/cinn/backends/llvm/execution_engine_test.cc
View file @
01a10755
...
...
@@ -41,8 +41,8 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
...
...
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
View file @
01a10755
...
...
@@ -19,10 +19,10 @@
#include <iostream>
#include "gflags/gflags_declare.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/utils/flags.h"
DECLARE_bool
(
verbose_function_register
);
PD_
DECLARE_bool
(
verbose_function_register
);
namespace
cinn
{
namespace
backends
{
...
...
paddle/cinn/backends/llvm/simple_jit.cc
View file @
01a10755
...
...
@@ -37,7 +37,7 @@
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
{
...
...
paddle/cinn/backends/modular.cc
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#include "paddle/cinn/backends/modular.h"
#include "paddle/cinn/ir/
utils/
ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace
cinn
{
namespace
backends
{
...
...
paddle/cinn/backends/nvrtc/nvrtc_util.cc
View file @
01a10755
...
...
@@ -30,8 +30,9 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string
(
cinn_nvcc_cmd_path
);
DECLARE_bool
(
nvrtc_compile_to_cubin
);
PD_DECLARE_string
(
cinn_nvcc_cmd_path
);
PD_DECLARE_bool
(
nvrtc_compile_to_cubin
);
PD_DECLARE_bool
(
cinn_nvrtc_cubin_with_fmad
);
namespace
cinn
{
namespace
backends
{
...
...
@@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code,
}
if
(
compile_to_cubin_
)
{
compile_options
.
push_back
(
"-arch=sm_"
+
cc
);
std
::
string
enable_fmad
=
FLAGS_cinn_nvrtc_cubin_with_fmad
?
"true"
:
"false"
;
compile_options
.
push_back
(
"--fmad="
+
enable_fmad
);
}
else
{
compile_options
.
push_back
(
"-arch=compute_"
+
cc
);
}
...
...
paddle/cinn/cinn.h
View file @
01a10755
...
...
@@ -29,6 +29,7 @@
namespace
cinn
{
using
ast_gen_ius
::
TensorGroup
;
using
backends
::
CodeGenC
;
using
backends
::
CodeGenCX86
;
using
backends
::
Outputs
;
...
...
@@ -39,6 +40,7 @@ using lang::CallExtern;
using
lang
::
CallLowered
;
using
lang
::
Compute
;
using
lang
::
Lower
;
using
lang
::
LowerToAst
;
using
lang
::
Placeholder
;
using
lang
::
ReduceAll
;
using
lang
::
ReduceAny
;
...
...
paddle/cinn/common/CMakeLists.txt
View file @
01a10755
...
...
@@ -19,10 +19,13 @@ gather_srcs(
arithmatic.cc
cas.cc
union_find.cc
python_interpreter_guard.cc
)
python_interpreter_guard.cc
nvgpu_dev_info.cc
)
message
(
STATUS
"srcs:
${
cinnapi_src
}
"
)
cinn_cc_test
(
test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog
)
cinn_cc_test
(
test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog
)
cinn_cc_test
(
test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog
)
cinn_cc_test
(
test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
...
...
paddle/cinn/common/arithmatic.cc
View file @
01a10755
...
...
@@ -21,9 +21,9 @@
#include <string>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
@@ -126,7 +126,7 @@ GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {
GiNaC
::
ex
ExprToGinacConverter
::
operator
()(
Expr
expr
)
{
// TODO(Superjomn) Replace this with common::IsPureMath(
auto
complex_nodes
=
CollectIRNodes
(
expr
,
[](
const
Expr
*
n
)
{
auto
complex_nodes
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
n
)
{
return
n
->
As
<
Block
>
()
||
//
n
->
As
<
PolyFor
>
()
||
//
n
->
As
<
EQ
>
()
||
//
...
...
@@ -262,7 +262,7 @@ bool IsPureMath(Expr expr) {
IrNodeTy
::
Minus
,
});
auto
complex_nodes
=
ir
::
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
n
)
{
auto
complex_nodes
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
n
)
{
return
!
valid_node_tys
.
count
(
n
->
node_type
());
});
#ifdef CINN_DEBUG
...
...
paddle/cinn/common/arithmatic_test.cc
View file @
01a10755
...
...
@@ -20,8 +20,8 @@
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
paddle/cinn/common/cas.cc
View file @
01a10755
...
...
@@ -21,13 +21,12 @@
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
@@ -1585,7 +1584,7 @@ bool CASasSymbol(Expr expr) {
Expr
ConvertCinnToCAS
(
Expr
expr
)
{
VLOG
(
7
)
<<
"Begin ConvertCinnToCAS "
<<
expr
;
Expr
copied
=
optim
::
IRCopy
(
expr
);
Expr
copied
=
ir
::
ir_utils
::
IRCopy
(
expr
);
struct
Mutator
:
public
ir
::
IRMutator
<
ir
::
Expr
*>
{
void
operator
()(
Expr
*
expr
)
{
Visit
(
expr
);
}
void
Visit
(
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
...
...
@@ -1711,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) {
* simplify the condition ensures correctness, though not sufficient.
*/
Expr
ReplaceMinToConstant
(
Expr
expr
)
{
Expr
copied
=
optim
::
IRCopy
(
expr
);
Expr
copied
=
ir
::
ir_utils
::
IRCopy
(
expr
);
struct
Mutator
:
public
ir
::
IRMutator
<
ir
::
Expr
*>
{
void
operator
()(
Expr
*
expr
)
{
Visit
(
expr
);
}
void
Visit
(
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
...
...
@@ -1728,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) {
auto
min_b
=
op
->
b
();
if
(
min_a
.
is_constant
()
&&
!
min_b
.
is_constant
())
{
CHECK
(
min_a
->
type
().
is_integer
());
*
expr
=
optim
::
IRCopy
(
min_a
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
min_a
);
}
else
if
(
min_b
.
is_constant
()
&&
!
min_a
.
is_constant
())
{
CHECK
(
min_b
->
type
().
is_integer
());
*
expr
=
optim
::
IRCopy
(
min_b
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
min_b
);
}
}
};
...
...
@@ -1744,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) {
* constant value and 1 inconstant value, return the constant max value.
*/
Expr
ReplaceMaxToConstant
(
Expr
expr
)
{
Expr
copied
=
optim
::
IRCopy
(
expr
);
Expr
copied
=
ir
::
ir_utils
::
IRCopy
(
expr
);
struct
Mutator
:
public
ir
::
IRMutator
<
ir
::
Expr
*>
{
void
operator
()(
Expr
*
expr
)
{
Visit
(
expr
);
}
void
Visit
(
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
...
...
@@ -1761,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto
max_b
=
op
->
b
();
if
(
max_a
.
is_constant
()
&&
!
max_b
.
is_constant
())
{
CHECK
(
max_a
->
type
().
is_integer
());
*
expr
=
optim
::
IRCopy
(
max_a
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
max_a
);
}
else
if
(
max_b
.
is_constant
()
&&
!
max_a
.
is_constant
())
{
CHECK
(
max_b
->
type
().
is_integer
());
*
expr
=
optim
::
IRCopy
(
max_b
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
max_b
);
}
}
};
...
...
@@ -1774,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) {
Expr
ConvertCasToCinn
(
Expr
expr
)
{
VLOG
(
7
)
<<
"Begin ConvertCasToCinn : "
<<
expr
;
Expr
copied
=
optim
::
IRCopy
(
expr
);
Expr
copied
=
ir
::
ir_utils
::
IRCopy
(
expr
);
struct
Mutator
:
ir
::
IRMutator
<
Expr
*>
{
void
operator
()(
Expr
*
expr
)
{
Visit
(
expr
);
}
...
...
@@ -1869,7 +1868,7 @@ bool IsExprCasCompatible(Expr expr) {
return
expr
->
As
<
Add
>
()
||
expr
->
As
<
Sub
>
()
||
expr
->
As
<
Mul
>
()
||
expr
->
As
<
Div
>
();
};
return
ir
::
CollectIRNodes
(
expr
,
teller
).
empty
();
return
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
teller
).
empty
();
}
// Partially divide a by b. e.g. (2x+y)/2 => x + y/2
...
...
paddle/cinn/common/cas.h
View file @
01a10755
...
...
@@ -20,7 +20,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace
cinn
{
...
...
paddle/cinn/common/cas_test.cc
View file @
01a10755
...
...
@@ -19,8 +19,8 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment