OpenDAS / Paddle / Commits / 01a10755

Commit 01a10755, authored Mar 04, 2024 by yuguo-Jack
Parent: 63eb0da5

Commit message:

    2.5.2-dtk24.04

The commit changes 565 files in total; the first 20 are shown below, with 494 additions and 122 deletions (+494 -122).
paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc   +4  -3
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc   +12  -14
paddle/cinn/auto_schedule/task/task_optimizer.cc   +3  -3
paddle/cinn/auto_schedule/task/task_registry.h   +2  -3
paddle/cinn/auto_schedule/task/task_registry_test.cc   +4  -5
paddle/cinn/auto_schedule/task/tune_task.cc   +1  -1
paddle/cinn/auto_schedule/task/tune_task.h   +5  -4
paddle/cinn/auto_schedule/task/tune_task_test.cc   +7  -4
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc   +28  -26
paddle/cinn/backends/codegen_c.cc   +6  -24
paddle/cinn/backends/codegen_c.h   +2  -3
paddle/cinn/backends/codegen_c_test.cc   +6  -3
paddle/cinn/backends/codegen_cuda_dev.cc   +3  -4
paddle/cinn/backends/codegen_cuda_dev.h   +1  -1
paddle/cinn/backends/codegen_cuda_generate_test.cc   +1  -1
paddle/cinn/backends/codegen_cuda_host.cc   +154  -0
paddle/cinn/backends/codegen_cuda_host.h   +13  -0
paddle/cinn/backends/codegen_cuda_util.cc   +91  -1
paddle/cinn/backends/codegen_cuda_util.h   +54  -4
paddle/cinn/backends/compiler.cc   +97  -18
paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc

@@ -27,6 +27,7 @@
 #include "paddle/cinn/auto_schedule/task/task_registry.h"
 #include "paddle/cinn/auto_schedule/task/tune_task.h"
 #include "paddle/cinn/auto_schedule/tuning.h"
+#include "paddle/cinn/hlir/framework/op_lowering.h"
 #include "paddle/cinn/ir/ir_base.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "test/cpp/cinn/program_builder.h"

@@ -44,11 +45,11 @@ std::vector<TuneTask> CreateTasks(const frontend::Program& program,
       "inferdtype");
   const auto& shape_dict =
       graph->GetAttrs<
           absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
           "infershape");
-  auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
-      dtype_dict, shape_dict, target);
+  auto op_lowerer =
+      hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
   InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
   for (auto i = 0; i < tasks.size(); ++i) {
-    tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
+    tasks[i].Initialize(shape_dict, dtype_dict, &op_lowerer);
     task_registry->Regist(tasks[i].serialized_key,
                           ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
   }
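The hunk above is the first instance of a pattern repeated across the auto-schedule tests in this commit: a heap-allocated `std::make_unique<hlir::framework::OpLowerer>(...)` plus `op_lowerer.get()` becomes a value returned by `hlir::framework::CreateOpLowerer(...)`, whose address is then passed to `Initialize`. A minimal sketch of that ownership shape, with stand-in types rather than the real CINN classes:

```cpp
#include <memory>

// Stand-in for hlir::framework::OpLowerer<GroupPtr>; illustrative only.
struct OpLowerer {};

// The new factory returns the lowerer by value instead of a unique_ptr.
OpLowerer CreateOpLowerer() { return OpLowerer{}; }

// TuneTask::Initialize borrows a raw pointer to a caller-owned lowerer.
void Initialize(OpLowerer* lower_handler) { /* borrow, don't own */ }

int main() {
  // Before: auto p = std::make_unique<OpLowerer>(); Initialize(p.get());
  // After: the lowerer lives on the stack and call sites take its address.
  auto op_lowerer = CreateOpLowerer();
  Initialize(&op_lowerer);
}
```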
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc

@@ -17,6 +17,7 @@
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 
+#include "paddle/cinn/ast_gen_ius/tensor_group.h"
 #include "paddle/cinn/cinn.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"

@@ -46,16 +47,13 @@ TEST(MutateTileSize, Basic) {
       [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
       "C");
 
-  poly::StageMap stages = CreateStages({A, B, C});
-  std::vector<ir::LoweredFunc> funcs =
-      lang::LowerVec("TestMutateTileSize_Basic",
-                     stages, {A, B, C}, {}, {}, nullptr, target, true);
+  ast_gen_ius::TensorGroup tensor_group({A, B, C});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerToAstVec(
+      "TestMutateTileSize_Basic", {A, B, C}, &tensor_group, target);
 
   ir::Expr ast_expr = funcs[0]->body;
   VLOG(6) << "Original Expr: ";

@@ -65,7 +63,7 @@ TEST(MutateTileSize, Basic) {
   // repeated.
   utils::LinearRandomEngine::StateType rand_seed = 123;
   ir::IRSchedule ir_schedule(module_expr, rand_seed);
-  ir::IRSchedule new_ir_schedule(ir_schedule);
+  ir::IRSchedule pir_schedule(ir_schedule);
   // apply schedule
   auto loops = ir_schedule.GetLoops("C");

@@ -76,13 +74,13 @@ TEST(MutateTileSize, Basic) {
   MutateTileSize mutator;
   ir::ScheduleDesc sch_desc =
       mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
-  sch_desc.Replay(&new_ir_schedule, true);
+  sch_desc.Replay(&pir_schedule, true);
   VLOG(6) << "Expr before mutate tile size:\n"
           << ir_schedule.GetModule().GetExprs()[0];
   VLOG(6) << "Expr after mutate tile size:\n"
-          << new_ir_schedule.GetModule().GetExprs()[0];
-  std::string target_new_ir = R"ROC({
+          << pir_schedule.GetModule().GetExprs()[0];
+  std::string target_pir = R"ROC({
   ScheduleBlock(root)
   {
     serial for (i_1, 0, 2)

@@ -117,7 +115,7 @@ TEST(MutateTileSize, Basic) {
     ss << exprs[0];
     return ss.str();
   };
-  ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir);
+  ASSERT_EQ(get_ir_str(&pir_schedule), target_pir);
   std::vector<int> last_tile_factors = {2, 16};
   for (int i = 0; i < 10; ++i) {
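Besides the rename, this test migrates from the stage-based lowering entry (`poly::StageMap` plus `lang::LowerVec`) to the AST path (`ast_gen_ius::TensorGroup` plus `lang::LowerToAstVec`). A schematic of the new call shape only, with stub types standing in for the CINN ones:

```cpp
#include <string>
#include <vector>

// Illustrative stand-ins; the real types live under paddle/cinn.
struct Tensor {};
struct Target {};
struct LoweredFunc {};

struct TensorGroup {  // plays the role of ast_gen_ius::TensorGroup
  explicit TensorGroup(const std::vector<Tensor>& tensors) : tensors_(tensors) {}
  std::vector<Tensor> tensors_;
};

// Shape of the new entry point: no StageMap, the group passed by pointer.
std::vector<LoweredFunc> LowerToAstVec(const std::string& name,
                                       const std::vector<Tensor>& args,
                                       TensorGroup* group,
                                       const Target& target) {
  return {LoweredFunc{}};
}

int main() {
  Tensor A, B, C;
  Target target;
  TensorGroup tensor_group({A, B, C});
  auto funcs = LowerToAstVec("TestMutateTileSize_Basic", {A, B, C},
                             &tensor_group, target);
}
```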
paddle/cinn/auto_schedule/task/task_optimizer.cc

@@ -40,7 +40,7 @@
 #include "paddle/cinn/backends/cuda_util.h"
 #endif
 
-DECLARE_bool(auto_schedule_use_cost_model);
+PD_DECLARE_bool(auto_schedule_use_cost_model);
 
 namespace cinn {
 namespace auto_schedule {

@@ -247,7 +247,7 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(
   auto& optimized_funcs = result.functions;
   auto& best_cost = result.cost;
   // use initial lowered function as default result
-  optimized_funcs = optim::IRCopy(task_->lowered_funcs);
+  optimized_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
   if (options.num_measure_trials == 0) {
     // no need to measure and simply return the best searched
     std::vector<MeasureInput> measure_candidates;

@@ -347,7 +347,7 @@ std::vector<SearchState> TaskOptimizer::SearchOneRound(
   CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size())
       << "RuntimeError: Expr size is not equal to LoweredFunc size in "
          "TaskOptimizer";
-  auto init_funcs = optim::IRCopy(task_->lowered_funcs);
+  auto init_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
   std::vector<ir::LoweredFunc> valid_funcs;
   for (size_t j = 0; j < best_exprs.size(); ++j) {
     auto updated_f =
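The `DECLARE_bool` to `PD_DECLARE_bool` edits here (and in most files below) move flag handling from raw gflags macros to Paddle's wrappers in paddle/utils/flags.h. A simplified sketch of what such a macro pair does; the real macros also register each flag with Paddle's flag registry, which these stand-ins omit:

```cpp
// Simplified stand-ins for the PD_* macros; illustrative only.
#define PD_DEFINE_bool(name, val, doc) bool FLAGS_##name = (val)
#define PD_DECLARE_bool(name) extern bool FLAGS_##name

PD_DEFINE_bool(auto_schedule_use_cost_model,
               true,
               "Whether to use the cost model in auto schedule");

int main() { return FLAGS_auto_schedule_use_cost_model ? 0 : 1; }
```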
paddle/cinn/auto_schedule/task/task_registry.h

@@ -14,14 +14,13 @@
 #pragma once
 
-#include <gflags/gflags.h>
-
 #include <mutex>
 #include <string>
 
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
 #include "paddle/cinn/utils/registry.h"
+#include "paddle/utils/flags.h"
 
 namespace cinn {

@@ -64,7 +63,7 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
     std::lock_guard<std::mutex> guard(registering_mutex);
     if (fmap_.count(task_key) == 0) {
       InitialTaskInfo* task_info =
-          new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
+          new InitialTaskInfo(task_key, ir::ir_utils::IRCopy(module_expr));
       __REGISTER__(task_key, task_info);
     }
   }
paddle/cinn/auto_schedule/task/task_registry_test.cc

@@ -28,7 +28,7 @@
 #include "paddle/cinn/utils/string.h"
 #include "paddle/cinn/utils/type_defs.h"
 
-DECLARE_bool(auto_schedule_use_cost_model);
+PD_DECLARE_bool(auto_schedule_use_cost_model);
 
 namespace cinn {
 namespace auto_schedule {

@@ -45,11 +45,10 @@ std::vector<TuneTask> CreateTasks(hlir::framework::Graph* graph,
   const auto& shape_dict =
       graph->GetAttrs<
           absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
           "infershape");
-  std::unique_ptr<hlir::framework::OpLowerer> op_lowerer =
-      std::make_unique<hlir::framework::OpLowerer>(
-          dtype_dict, shape_dict, target);
+  auto op_lowerer =
+      hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
 
   for (TuneTask& task : tasks) {
-    task.Initialize(shape_dict, dtype_dict, op_lowerer.get());
+    task.Initialize(shape_dict, dtype_dict, &op_lowerer);
     VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
   }
paddle/cinn/auto_schedule/task/tune_task.cc

@@ -34,7 +34,7 @@ void TuneTask::Initialize(
     const absl::flat_hash_map<std::string, hlir::framework::shape_t>&
         shape_dict,
     const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict,
-    hlir::framework::OpLowerer* lower_handler) {
+    hlir::framework::OpLowerer<GroupPtr>* lower_handler) {
   CHECK(lower_handler != nullptr) << "op_lowerer can't be nullptr";
   op_lowerer = lower_handler;
paddle/cinn/auto_schedule/task/tune_task.h

@@ -34,16 +34,17 @@ namespace cinn {
 namespace auto_schedule {
 
 class TuneTask {
+  using GroupPtr = hlir::framework::GroupPtr;
+
  public:
   TuneTask() = default;
-  explicit TuneTask(std::shared_ptr<hlir::framework::Graph::Group> group)
-      : subgraph(group) {}
+  explicit TuneTask(GroupPtr group) : subgraph(group) {}
   // Initialize a task
   void Initialize(
       const absl::flat_hash_map<std::string, hlir::framework::shape_t>&
           shape_dict,
       const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict,
-      hlir::framework::OpLowerer* lower_handler);
+      hlir::framework::OpLowerer<GroupPtr>* lower_handler);
   // Extract bodies in lowered_funcs() and return
   std::vector<ir::Expr> GetLoweredFuncBodyExprs() const;

@@ -51,7 +52,7 @@ class TuneTask {
   // sub-graph (if an op won't be fused, it will be a Group with size=1).
   std::shared_ptr<hlir::framework::Graph::Group> subgraph;
   // Lower handler, Not owned
-  hlir::framework::OpLowerer* op_lowerer;
+  hlir::framework::OpLowerer<GroupPtr>* op_lowerer;
   // target of this task
   common::Target target;
   // stores the initial (un-optimized) LoweredFuncs
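`OpLowerer` is now a class template parameterized by the group type, which is why both this header and its callers spell out `OpLowerer<GroupPtr>`. A toy illustration of the shape of that change, with stand-in types only:

```cpp
#include <memory>

// Stand-ins: OpLowerer becomes a template over the group type it lowers.
struct Group {};
using GroupPtr = std::shared_ptr<Group>;

template <typename T>
struct OpLowerer {
  void Lower(const T& group) { /* lower one fusion group */ }
};

struct TuneTask {
  // Lower handler, not owned (mirrors the field changed in this diff).
  OpLowerer<GroupPtr>* op_lowerer = nullptr;
};

int main() {
  OpLowerer<GroupPtr> lowerer;
  TuneTask task;
  task.op_lowerer = &lowerer;
  task.op_lowerer->Lower(std::make_shared<Group>());
}
```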
paddle/cinn/auto_schedule/task/tune_task_test.cc

@@ -31,8 +31,8 @@
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/hlir/framework/scope.h"
 #include "paddle/cinn/ir/ir_base.h"
-#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
+#include "paddle/cinn/ir/utils/ir_printer.h"
 #include "paddle/cinn/utils/string.h"
 
 namespace cinn {

@@ -75,7 +75,8 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
   const auto& dtype_dict =
       graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
           "inferdtype");
 
-  OpLowerer op_lowerer(dtype_dict, shape_dict, target);
+  auto op_lowerer =
+      hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target);
 
   std::stringstream ss;
   for (TuneTask& task : tasks) {

@@ -187,7 +188,8 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) {
       graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
           "inferdtype");
 
-  OpLowerer op_lowerer(dtype_dict, shape_dict, target);
+  OpLowerer op_lowerer(
+      new hlir::framework::OpLowererImpl(dtype_dict, shape_dict, target));
 
   std::stringstream ss;
   for (TuneTask& task : tasks) {

@@ -291,7 +293,8 @@ TEST(TuneTask, SerializeToString) {
   const auto& dtype_dict =
       graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
           "inferdtype");
-  OpLowerer op_lowerer(dtype_dict, shape_dict, target);
+  OpLowerer op_lowerer(
+      new hlir::framework::OpLowererImpl(dtype_dict, shape_dict, target));
 
   ASSERT_EQ(single_tasks.size(), 2UL);
   for (auto&& task : single_tasks) {
     task.Initialize(shape_dict, dtype_dict, &op_lowerer);
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc

@@ -25,7 +25,9 @@
 #include "paddle/cinn/frontend/paddle_model_convertor.h"
 #include "paddle/cinn/frontend/syntax.h"
 #include "paddle/cinn/hlir/framework/graph_compiler.h"
+#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
 #include "paddle/cinn/hlir/framework/node.h"
+#include "paddle/cinn/hlir/framework/op_lowering.h"
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/ir/ir_base.h"
 #include "paddle/cinn/runtime/flags.h"

@@ -42,7 +44,7 @@
  * parameters for more detail.
  */
-DEFINE_string(resnet50_model_dir,
-              "./ResNet50",
-              "the path to paddle model resnet50.");
+PD_DEFINE_string(resnet50_model_dir,
+                 "./ResNet50",
+                 "the path to paddle model resnet50.");
 // Flags that control which schedule tests will be run.

@@ -52,15 +54,16 @@ DEFINE_string(resnet50_model_dir,
 // auto schedule test, means options = 4 = "100" will run auto schedule test.
 // The default value is -1, which means that this flag is disabled to set the
 // options
-DEFINE_int32(evaluate_knobs,
-             -1,
-             "the options to control which schedule tests will be run.");
-DECLARE_double(cinn_infer_model_version);
+PD_DEFINE_int32(evaluate_knobs,
+                -1,
+                "the options to control which schedule tests will be run.");
+PD_DECLARE_double(cinn_infer_model_version);
 
 namespace cinn {
 namespace auto_schedule {
 
 using ::cinn::hlir::framework::BuildScope;
+using ::cinn::hlir::framework::CompilationContext;
 using ::cinn::hlir::framework::Graph;
 using ::cinn::hlir::framework::GraphCompiler;
 using ::cinn::hlir::framework::Instruction;

@@ -94,8 +97,8 @@ class PerformanceTester : public ::testing::Test {
     hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
     VLOG(3) << "Build " << schedule_name << " program.";
     auto scope = BuildScope(target_, graph);
-    auto graph_compiler =
-        std::make_unique<GraphCompiler>(target_, scope, graph);
+    CompilationContext context(graph, scope, target_);
+    auto graph_compiler = std::make_unique<GraphCompiler>(context);
     auto runtime_program =
         (this->*build_fn)(graph.get(), graph_compiler.get());
     if (execute) {

@@ -141,28 +144,27 @@ class PerformanceTester : public ::testing::Test {
             absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
             "infershape");
 
-    std::shared_ptr<hlir::framework::OpLowerer> op_lowerer =
-        std::make_unique<hlir::framework::OpLowerer>(
-            dtype_dict, shape_dict, target_);
+    auto op_lowerer =
+        hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target_);
 
-    GraphCompiler::CompileOptions compile_options;
-    compile_options.with_instantiate_variables = true;
+    CompilationContext& context = graph_compiler->GetCompilationContext();
+    context.with_instantiate_variables = true;
 
     if (graph->fusion_groups.empty()) {
       hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
     }
-    compile_options.groups = graph->fusion_groups;
+    context.groups = graph->fusion_groups;
 
     for (auto group : graph->fusion_groups) {
-      compile_options.lowered_funcs.push_back(
-          op_lowerer->Lower(group,
-                            /*apply_op_schedule = */ false,
-                            /*apply_group_schedule=*/false));
+      context.lowered_funcs.push_back(
+          op_lowerer.Lower(group,
+                           /*apply_op_schedule = */ false,
+                           /*apply_group_schedule=*/false));
    }
 
     VLOG(3) << "===========================No Schedule LoweredFunc "
                "Begin===========================";
-    for (const auto& funcvec : compile_options.lowered_funcs) {
+    for (const auto& funcvec : context.lowered_funcs) {
       for (const auto& func : funcvec) {
         VLOG(3) << func;
       }

@@ -170,7 +172,7 @@ class PerformanceTester : public ::testing::Test {
     VLOG(3) << "===========================No Schedule LoweredFunc "
                "End=============================";
-    return graph_compiler->Build(compile_options).runtime_program;
+    return graph_compiler->Build();
   }
 
   std::unique_ptr<hlir::framework::Program> BuildManualScheduleProgram(

@@ -191,13 +193,13 @@ class PerformanceTester : public ::testing::Test {
     tuner->Initialize(tuning_config, graph_compiler);
     TuningResult tuning_result = tuner->Tune(tuning_options);
 
-    GraphCompiler::CompileOptions compile_options;
-    compile_options.with_instantiate_variables = true;
-    compile_options.Apply(tuning_result);
+    CompilationContext& context = graph_compiler->GetCompilationContext();
+    context.with_instantiate_variables = true;
+    context.ApplyTuningResult(tuning_result);
 
     VLOG(3) << "===========================Auto Schedule LoweredFunc "
                "Begin===========================";
-    for (const auto& funcvec : compile_options.lowered_funcs) {
+    for (const auto& funcvec : context.lowered_funcs) {
       for (const auto& func : funcvec) {
         VLOG(3) << func;
       }

@@ -205,7 +207,7 @@ class PerformanceTester : public ::testing::Test {
     VLOG(3) << "===========================Auto Schedule LoweredFunc "
                "End=============================";
-    return graph_compiler->Build(compile_options).runtime_program;
+    return graph_compiler->Build();
   }
 
 #ifdef CINN_WITH_CUDA
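The recurring `GraphCompiler` change in this test: per-call `CompileOptions` are replaced by a `CompilationContext` that the compiler owns and hands out via `GetCompilationContext()`, so `Build()` no longer takes an options argument. A stand-in sketch of that API shape (not the real CINN classes):

```cpp
#include <memory>
#include <vector>

// Stand-ins sketching the API change only.
struct CompilationContext {
  bool with_instantiate_variables = false;
  std::vector<int> groups;  // placeholder for fusion groups
};

struct Program {};

struct GraphCompiler {
  explicit GraphCompiler(const CompilationContext& ctx) : context_(ctx) {}
  CompilationContext& GetCompilationContext() { return context_; }
  std::unique_ptr<Program> Build() { return std::make_unique<Program>(); }

 private:
  CompilationContext context_;
};

int main() {
  GraphCompiler compiler(CompilationContext{});
  // Options are now set through the compiler-owned context...
  CompilationContext& context = compiler.GetCompilationContext();
  context.with_instantiate_variables = true;
  // ...and Build() takes no options argument anymore.
  auto runtime_program = compiler.Build();
}
```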
paddle/cinn/backends/codegen_c.cc

@@ -23,13 +23,12 @@
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-#include "paddle/cinn/optim/remove_nested_block.h"
 #include "paddle/cinn/runtime/cpu/thread_backend.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 #include "paddle/cinn/utils/string.h"
 
 //! Root of the builtin code.
-DECLARE_string(cinn_x86_builtin_code_root);
+PD_DECLARE_string(cinn_x86_builtin_code_root);
 
 namespace cinn {
 namespace backends {

@@ -39,7 +38,7 @@ using cinn::common::float16;
 
 const char *kCKeywordRestrict = "__restrict__";
 
 void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
-  ir::IrVerify(Expr(module));
+  ir::ir_utils::IrVerify(Expr(module));
 
   if (!outputs.c_header_name.empty()) {
     auto source = Compile(module, OutputKind::CHeader);

@@ -286,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) {
 void CodeGenC::Visit(const ir::IfThenElse *op) {
   str_ += "if (";
   IrPrinter::Visit(op->condition);
-  str_ += ") {\n";
-
-  if (!op->true_case.As<ir::Block>()) IncIndent();
-  DoIndent();
+  str_ += ") ";
   IrPrinter::Visit(op->true_case);
-  if (!op->true_case.As<ir::Block>()) str_ += ";";
-  str_ += "\n";
-  if (!op->true_case.As<ir::Block>()) DecIndent();
-  DoIndent();
-  str_ += "}";
 
   if (op->false_case.defined()) {
-    str_ += " else {\n";
-    if (!op->true_case.As<ir::Block>()) IncIndent();
-    DoIndent();
+    str_ += " else ";
     IrPrinter::Visit(op->false_case);
-    if (!op->false_case.As<ir::Block>()) str_ += ";";
-    str_ += "\n";
-    if (!op->true_case.As<ir::Block>()) DecIndent();
-    DoIndent();
-    str_ += "}";
   }
 }
 
 void CodeGenC::Visit(const ir::Block *op) {

@@ -645,7 +626,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
   Expr func_body = ir::Block::Make(new_body);
 
-  optim::RemoveNestedBlock(&func_body);
+  optim::SimplifyBlocks(&func_body);
 
   IrPrinter::Visit(func_body);
 }

@@ -766,6 +747,7 @@ void CodeGenC::Visit(const ir::ScheduleBlock *op) { CINN_NOT_IMPLEMENTED }
 void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) { CINN_NOT_IMPLEMENTED }
+void CodeGenC::Visit(const ir::_Dim_ *op) { CINN_NOT_IMPLEMENTED }
 
 void CodeGenC::Visit(const ir::IntrinsicOp *op) {
   switch (op->getKind()) {
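The `IfThenElse` printer above shrinks because brace and indentation bookkeeping now happens when the branch itself is printed (branches are blocks), so the if-printer can simply delegate. A toy printer showing the resulting output shape, not the actual CINN classes:

```cpp
#include <iostream>
#include <string>

// Toy illustration of the simplified printing strategy.
struct Printer {
  std::string str_;
  void VisitBlock(const std::string& body) { str_ += "{\n" + body + "}"; }
  void VisitIf(const std::string& cond,
               const std::string& true_case,
               const std::string& false_case) {
    str_ += "if (";
    str_ += cond;
    str_ += ") ";            // was: ") {\n" plus manual indent management
    VisitBlock(true_case);   // the branch prints its own braces
    if (!false_case.empty()) {
      str_ += " else ";
      VisitBlock(false_case);
    }
  }
};

int main() {
  Printer p;
  p.VisitIf("x > 0", "  y = 1;\n", "  y = 0;\n");
  std::cout << p.str_ << "\n";  // if (x > 0) {...} else {...}
}
```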
paddle/cinn/backends/codegen_c.h

@@ -14,19 +14,18 @@
 #pragma once
 
-#include <gflags/gflags.h>
-
 #include <string>
 #include <vector>
 
 #include "paddle/cinn/common/common.h"
 #include "paddle/cinn/ir/intrinsic_ops.h"
 #include "paddle/cinn/ir/ir.h"
-#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/lowered_func.h"
 #include "paddle/cinn/ir/module.h"
+#include "paddle/cinn/ir/utils/ir_printer.h"
 #include "paddle/cinn/lang/packed_func.h"
 #include "paddle/cinn/runtime/cinn_runtime.h"
+#include "paddle/utils/flags.h"
 
 namespace cinn {
paddle/cinn/backends/codegen_c_test.cc

@@ -19,6 +19,7 @@
 #include <sstream>
 #include <tuple>
 
+#include "paddle/cinn/ast_gen_ius/tensor_group.h"
 #include "paddle/cinn/cinn.h"
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/module.h"

@@ -65,8 +66,10 @@ TEST(CodeGenC, module) {
   target.os = Target::OS::Linux;
 
   Module::Builder builder("module1", target);
-  auto stages = CreateStages({A, B, C});
-  auto func = Lower("add1", stages, {A, B, C});
+  ast_gen_ius::TensorGroup tensor_group({A, B, C});
+  auto func = lang::LowerToAst("add1", {A, B, C}, &tensor_group);
 
   LOG(INFO) << "Func to codegen: " << func << std::endl;
 
   builder.AddFunction(func);

@@ -74,7 +77,7 @@ TEST(CodeGenC, module) {
   CodeGenC codegen(target);
   codegen.SetInlineBuiltinCodes(false);
   auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
-  std::cout << "codegen C:" << std::endl << out << std::endl;
+  LOG(INFO) << "codegen C:" << std::endl << out << std::endl;
 
   std::string target_str = R"ROC(
 #include <cinn_runtime.h>
paddle/cinn/backends/codegen_cuda_dev.cc

@@ -24,7 +24,6 @@
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-#include "paddle/cinn/optim/remove_nested_block.h"
 
 namespace cinn {
 namespace backends {

@@ -57,7 +56,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, bool for_nvrtc) {
 void CodeGenCUDA_Dev::Compile(const ir::Module &module, const Outputs &outputs) {
-  ir::IrVerify(Expr(module));
+  ir::ir_utils::IrVerify(Expr(module));
 
   CodeGenC::inline_builtin_codes_ = false;
   if (!outputs.c_header_name.empty()) {

@@ -91,7 +90,7 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
       temp_buffers.end());
   // prepare temp buffer alias
   std::vector<Expr> buffer_alias;
-  auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) {
+  auto tensors = ir::ir_utils::CollectIRNodes(op->body, [&](const Expr *x) {
     return x->as_tensor() && x->as_tensor()->buffer.defined() &&
            temp_buffer_set.count(x->as_tensor()->buffer);
   });

@@ -141,7 +140,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
   Expr func_body = ir::Block::Make(new_body);
 
-  optim::RemoveNestedBlock(&func_body);
+  optim::SimplifyBlocks(&func_body);
   // Make sure that the function's body is wrapped by a block
   if (!func_body.As<ir::Block>()) {
     func_body = ir::Block::Make({func_body});
paddle/cinn/backends/codegen_cuda_dev.h

@@ -20,9 +20,9 @@
 #include "paddle/cinn/backends/codegen_c.h"
 #include "paddle/cinn/common/common.h"
 #include "paddle/cinn/ir/ir.h"
-#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/lowered_func.h"
 #include "paddle/cinn/ir/module.h"
+#include "paddle/cinn/ir/utils/ir_printer.h"
 #include "paddle/cinn/lang/packed_func.h"
 #include "paddle/cinn/runtime/cinn_runtime.h"
paddle/cinn/backends/codegen_cuda_generate_test.cc

@@ -30,8 +30,8 @@
 #include "paddle/cinn/common/test_helper.h"
 #include "paddle/cinn/hlir/pe/nn.h"
 #include "paddle/cinn/hlir/pe/schedule.h"
-#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
+#include "paddle/cinn/ir/utils/ir_printer.h"
 #include "paddle/cinn/lang/lower.h"
 #include "paddle/cinn/optim/ir_simplify.h"
 #include "paddle/cinn/utils/timer.h"
paddle/cinn/backends/codegen_cuda_host.cc

@@ -182,5 +182,159 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(
   return function;
 }
 
+llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
+  // Create the function
+  // @{
+  auto* function_type = GenFunctionTypeFromCinnFunction(func, true);
+  f_ = llvm::Function::Create(
+      function_type, llvm::Function::ExternalLinkage, func->name, m_);
+  f_->setCallingConv(llvm::CallingConv::C);
+  f_->setHasUWTable();
+
+  std::vector<llvm::Value*> ll_function_args;
+  std::transform(f_->arg_begin(),
+                 f_->arg_end(),
+                 std::back_inserter(ll_function_args),
+                 [](auto& arg) { return std::addressof(arg); });
+  // @}
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(
+      /*Context=*/b_->getContext(),
+      /*Name=*/"entry",
+      /*Parent=*/f_,
+      /*InsertBefore=*/nullptr);
+  b_->SetInsertPoint(entry);
+  CodeGenLLVM::Visit(&func->body);
+  RetVoid();
+
+  return f_;
+}
+
+llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
+  std::vector<llvm::Value*> ll_function_args;
+  std::transform(f_->arg_begin(),
+                 f_->arg_end(),
+                 std::back_inserter(ll_function_args),
+                 [](auto& arg) { return std::addressof(arg); });
+  auto* kernel_args = ll_function_args[0];
+  auto* kernel_args_count = ll_function_args[1];
+  llvm::Value* kernel_stream = nullptr;
+  if (ll_function_args.size() == 3) {
+    kernel_stream = ll_function_args[2];
+    CHECK_EQ(kernel_stream->getType(), ll_void_p_ty());  // void* stream
+  }
+  CHECK_EQ(kernel_args->getType(), ll_void_p_ty());       // void* args
+  CHECK_EQ(kernel_args_count->getType(), ll_int32_ty());  // int32
+
+  std::unordered_map<std::string, llvm::Value*> global_args = {
+      {KERNEL_ARGS, kernel_args},
+      {KERNEL_ARGS_NUM, kernel_args_count},
+      {KERNEL_STREAM, kernel_stream}};
+
+  auto ret_type = CinnTypeToLLVMType(Void(), m_);
+  std::vector<llvm::Type*> args_type;
+  for (auto r_arg : call_ir->read_args) {
+    if (r_arg.is_var()) {
+      if (r_arg.as_var()->type().is_cpp_handle() ||
+          r_arg.as_var()->type().is_string()) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
+      } else if (r_arg.as_var()->type().is_int(32)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
+      } else {
+        CINN_NOT_IMPLEMENTED;
+      }
+    } else {
+      if (r_arg.type().is_bool()) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<bool>(), m_));
+      } else if (r_arg.type().is_uint(8)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<uint8_t>(), m_));
+      } else if (r_arg.type().is_uint(16)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<uint16_t>(), m_));
+      } else if (r_arg.type().is_uint(32)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<uint32_t>(), m_));
+      } else if (r_arg.type().is_uint(64)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<uint64_t>(), m_));
+      } else if (r_arg.type().is_int(8)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<int8_t>(), m_));
+      } else if (r_arg.type().is_int(16)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<int16_t>(), m_));
+      } else if (r_arg.type().is_int(32)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
+      } else if (r_arg.type().is_int(64)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<int64_t>(), m_));
+      } else if (r_arg.type().is_float(32)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<float>(), m_));
+      } else if (r_arg.type().is_float(64)) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<double>(), m_));
+      } else if (r_arg.type().is_bfloat16()) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<bfloat16>(), m_));
+      } else if (r_arg.type().is_float16()) {
+        args_type.push_back(CinnTypeToLLVMType(type_of<float16>(), m_));
+      } else {
+        CINN_NOT_IMPLEMENTED;
+      }
+    }
+  }
+
+  auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
+  auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);
+
+  std::vector<llvm::Value*> call_args;
+  for (auto& r_arg : call_ir->read_args) {
+    if (r_arg.is_var()) {
+      if (r_arg.as_var()->type().is_string()) {
+        auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_",
+                                            b_->getInt8PtrTy());
+        call_args.push_back(b_->CreateLoad(
+            b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
+      } else if (r_arg.as_var()->type().is_cpp_handle() ||
+                 r_arg.as_var()->type().is_int(32)) {
+        CHECK(global_args.count(r_arg.as_var()->name));
+        call_args.push_back(global_args[r_arg.as_var()->name]);
+      } else {
+        CINN_NOT_IMPLEMENTED;
+      }
+    } else {
+      if (r_arg.type().is_bool()) {
+        call_args.push_back(b_->getInt1(r_arg.as_bool()));
+      } else if (r_arg.type().is_int(8)) {
+        call_args.push_back(b_->getInt8(r_arg.as_int8()));
+      } else if (r_arg.type().is_int(16)) {
+        call_args.push_back(b_->getInt16(r_arg.as_int16()));
+      } else if (r_arg.type().is_int(32)) {
+        call_args.push_back(b_->getInt32(r_arg.as_int32()));
+      } else if (r_arg.type().is_int(64)) {
+        call_args.push_back(b_->getInt64(r_arg.as_int64()));
+      } else if (r_arg.type().is_uint(8)) {
+        call_args.push_back(b_->getInt8(r_arg.as_uint8()));
+      } else if (r_arg.type().is_uint(16)) {
+        call_args.push_back(b_->getInt16(r_arg.as_uint16()));
+      } else if (r_arg.type().is_uint(32)) {
+        call_args.push_back(b_->getInt32(r_arg.as_uint32()));
+      } else if (r_arg.type().is_uint(64)) {
+        call_args.push_back(b_->getInt64(r_arg.as_uint64()));
+      } else if (r_arg.type().is_float(32)) {
+        call_args.push_back(llvm::ConstantFP::get(
+            b_->getFloatTy(), llvm::APFloat(r_arg.as_float())));
+      } else if (r_arg.type().is_float(64)) {
+        call_args.push_back(llvm::ConstantFP::get(
+            b_->getDoubleTy(), llvm::APFloat(r_arg.as_double())));
+      } else if (r_arg.type().is_bfloat16()) {
+        call_args.push_back(llvm::ConstantFP::get(
+            b_->getBFloatTy(),
+            llvm::APFloat(static_cast<float>(r_arg.as_bfloat16()))));
+      } else if (r_arg.type().is_float16()) {
+        call_args.push_back(llvm::ConstantFP::get(
+            b_->getHalfTy(),
+            llvm::APFloat(static_cast<float>(r_arg.as_float16()))));
+      } else {
+        CINN_NOT_IMPLEMENTED;
+      }
+    }
+  }
+  b_->CreateCall(call_func, call_args);
+  return nullptr;
+}
+
 }  // namespace backends
 }  // namespace cinn
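The new `LowerHostFunc`/`LowerCUDAKernelCall` pair targets host wrappers with a fixed parameter list, `(void* kernel_args, int kernel_args_num, void* kernel_stream)`, mapped to the `KERNEL_ARGS`/`KERNEL_ARGS_NUM`/`KERNEL_STREAM` globals and forwarded to the runtime launch call. A plain-C++ sketch of that calling convention; the stub function is a made-up stand-in for `runtime::intrinsic::call_cuda_kernel`:

```cpp
#include <cstdio>

// Hypothetical stand-in for the runtime launch entry point.
void call_cuda_kernel_stub(void* args, int num_args, void* stream) {
  std::printf("launch: %d args, stream %p\n", num_args, stream);
}

// Shape of the host wrapper the lowered IR compiles down to.
void host_wrapper(void* kernel_args, int kernel_args_num, void* kernel_stream) {
  call_cuda_kernel_stub(kernel_args, kernel_args_num, kernel_stream);
}

int main() { host_wrapper(nullptr, 0, nullptr); }
```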
paddle/cinn/backends/codegen_cuda_host.h

@@ -23,6 +23,8 @@
 #include "paddle/cinn/backends/llvm/codegen_llvm.h"
 
+PD_DECLARE_bool(cinn_bucket_compile);
+
 namespace cinn {
 namespace backends {

@@ -38,9 +40,16 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
   using CodeGenLLVM::Visit;
   llvm::Value* Visit(const ir::_LoweredFunc_* func) override {
+    if (FLAGS_cinn_bucket_compile) {
+      return LowerHostFunc(func);
+    }
     return LowerGPUKernelLauncher(func);
   }
 
+  llvm::Value* Visit(const ir::Call* op) override {
+    return LowerCUDAKernelCall(op);
+  }
+
  private:
   /**
    * Lower a CUDA kernel launcher.

@@ -56,6 +65,10 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
    *
    */
   llvm::Value* LowerGPUKernelLauncher(const ir::_LoweredFunc_* func);
 
+  llvm::Value* LowerHostFunc(const ir::_LoweredFunc_* func);
+
+  llvm::Value* LowerCUDAKernelCall(const ir::Call* op);
+
 };
 
 }  // namespace backends
paddle/cinn/backends/codegen_cuda_util.cc

@@ -15,16 +15,106 @@
 #include "paddle/cinn/backends/codegen_cuda_util.h"
 
 #include "paddle/cinn/backends/cuda_util.h"
-#include "paddle/cinn/ir/utils/ir_mutator.h"
+#include "paddle/cinn/ir/ir_mutator.h"
+
+PD_DECLARE_bool(cinn_bucket_compile);
 
 namespace cinn {
 namespace backends {
 
 std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
+  if (FLAGS_cinn_bucket_compile) {
+    detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name);
+    Expr expr(module);
+    return visitor(&expr);
+  }
   detail::CollectHostFunctionVisitor visitor(module->name);
   Expr expr(module);
   return visitor(&expr);
 }
 
+struct PredicatePrinter : public ir::IrPrinter {
+  explicit PredicatePrinter(std::ostream& os) : ir::IrPrinter(os) {}
+
+ private:
+  void Visit(const ir::Add* x) { PrintBinaryOp("ADD", x); }
+  void Visit(const ir::Sub* x) { PrintBinaryOp("SUB", x); }
+  void Visit(const ir::Mul* x) { PrintBinaryOp("MUL", x); }
+  void Visit(const ir::Div* x) { PrintBinaryOp("DIV", x); }
+  void Visit(const ir::Mod* x) { PrintBinaryOp("MOD", x); }
+  void Visit(const ir::EQ* x) { PrintBinaryOp("EQ", x); }
+  void Visit(const ir::NE* x) { PrintBinaryOp("NE", x); }
+  void Visit(const ir::LT* x) { PrintBinaryOp("LT", x); }
+  void Visit(const ir::LE* x) { PrintBinaryOp("LE", x); }
+  void Visit(const ir::GT* x) { PrintBinaryOp("GT", x); }
+  void Visit(const ir::GE* x) { PrintBinaryOp("GE", x); }
+  void Visit(const ir::And* x) { PrintBinaryOp("AND", x); }
+  void Visit(const ir::Or* x) { PrintBinaryOp("OR", x); }
+
+  template <typename IRN>
+  void PrintBinaryOp(const std::string& op, const ir::BinaryOpNode<IRN>* x) {
+    str_ += "_FPA_";
+    ir::IrPrinter::Visit(x->a());
+    str_ += op;
+    ir::IrPrinter::Visit(x->b());
+    str_ += "_BPA_";
+  }
+};
+
+std::string Predicate2String(ir::Expr predicate) {
+  std::stringstream ss;
+  PredicatePrinter cond_printer(ss);
+  cond_printer.Print(predicate);
+  return ss.str();
+}
+
+std::string
+detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
+    const std::string& fn_name, ir::Expr predicate) {
+  std::string cond_str = Predicate2String(predicate);
+  VLOG(3) << "predicate string: " << cond_str;
+  return fn_name + "__COND_" + cond_str + "__kernel";
+}
+
+void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
+    ir::Expr func, ir::Expr predicate) {
+  ir::_LoweredFunc_* func_node = func.as_lowered_func();
+  CHECK(func_node);
+  if (!func_node->cuda_axis_info.valid()) {
+    func_node->cuda_axis_info.set_valid(true);
+  }
+  // process device func
+  device_module_builder.AddFunctionWithoutOptim(
+      CreateDeviceFunction(func, predicate).as_lowered_func_ref());
+  // process host func
+  ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate),
+                     type_of<std::string>());
+  ir::Expr call_extern_api =
+      ir::Call::Make(Void(),
+                     runtime::intrinsic::call_cuda_kernel,
+                     {kernel_ptr,
+                      kernel_args_,
+                      kernel_args_num_,
+                      Expr(func_node->cuda_axis_info.grid_dim(0)),   // grid_x
+                      Expr(func_node->cuda_axis_info.grid_dim(1)),   // grid_y
+                      Expr(func_node->cuda_axis_info.grid_dim(2)),   // grid_z
+                      Expr(func_node->cuda_axis_info.block_dim(0)),  // block_x
+                      Expr(func_node->cuda_axis_info.block_dim(1)),  // block_y
+                      Expr(func_node->cuda_axis_info.block_dim(2)),  // block_z
+                      kernel_stream_},
+                     {},
+                     ir::CallType::Extern,
+                     ir::FunctionRef(),
+                     0);
+  buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
+}
+
+Expr detail::CollectBucketStrategyHostFunctionVisitor::CreateDeviceFunction(
+    ir::Expr expr, ir::Expr predicate) {
+  auto copied = ir::ir_utils::IRCopy(expr);
+  auto* lowered_func = copied.as_lowered_func();
+  lowered_func->name = GenDeviceKernelName(lowered_func->name, predicate);
+  return copied;
+}
+
 }  // namespace backends
 }  // namespace cinn
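`GenDeviceKernelName` makes each bucket's kernel symbol unique by serializing the guarding predicate into the name, using `_FPA_`/`_BPA_` as parenthesis markers and uppercase opcodes. A standalone sketch of the mangling, assuming a predicate like `S0 < 1024` (the real code serializes the IR expression via `PredicatePrinter`; here the pieces are passed as strings):

```cpp
#include <iostream>
#include <string>

// Simplified version of the predicate serialization for one binary op.
std::string Predicate2String(const std::string& lhs,
                             const std::string& op,
                             const std::string& rhs) {
  return "_FPA_" + lhs + op + rhs + "_BPA_";  // e.g. _FPA_S0LT1024_BPA_
}

// Splice the serialized predicate into the device kernel symbol.
std::string GenDeviceKernelName(const std::string& fn_name,
                                const std::string& cond_str) {
  return fn_name + "__COND_" + cond_str + "__kernel";
}

int main() {
  std::cout << GenDeviceKernelName("fn_reduce_sum",
                                   Predicate2String("S0", "LT", "1024"))
            << "\n";  // fn_reduce_sum__COND__FPA_S0LT1024_BPA___kernel
}
```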
paddle/cinn/backends/codegen_cuda_util.h

@@ -22,8 +22,8 @@
 #include "paddle/cinn/cinn.h"
 #include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
-#include "paddle/cinn/ir/utils/ir_mutator.h"
 
 namespace cinn {
 namespace backends {

@@ -57,7 +57,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
                            device_module_builder.Build());
   }
 
- private:
+ protected:
   void Visit(const ir::_LoweredFunc_* op, Expr* expr) override {
     if (op->body.As<ir::Call>()) {
       host_module_builder.AddFunctionWithoutOptim(
           expr->as_lowered_func_ref());

@@ -127,7 +127,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
   }
 
   Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) {
-    auto copied = optim::IRCopy(expr);
+    auto copied = ir::ir_utils::IRCopy(expr);
     auto* lowered_func = copied.as_lowered_func();
     lowered_func->name = GenDeviceKernelName(lowered_func->name);
     return copied;

@@ -137,11 +137,61 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
     return fn + "_kernel";
   }
 
- private:
+ protected:
   ir::Module::Builder host_module_builder;
   ir::Module::Builder device_module_builder;
 };
 
+struct CollectBucketStrategyHostFunctionVisitor
+    : public CollectHostFunctionVisitor {
+  explicit CollectBucketStrategyHostFunctionVisitor(
+      const std::string& module_name)
+      : CollectHostFunctionVisitor(module_name),
+        kernel_args_(KERNEL_ARGS, type_of<void*>()),
+        kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
+        kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}
+
+  std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
+    ir::IRMutator<>::Visit(expr, expr);
+    return std::make_tuple(host_module_builder.Build(),
+                           device_module_builder.Build());
+  }
+
+ private:
+  void Visit(const ir::_Module_* op, Expr* expr) {
+    CHECK_EQ(op->functions.size(), op->predicates.size());
+    for (int i = 0; i < op->functions.size(); ++i) {
+      ProcessLoweredFunc(op->functions[i], op->predicates[i]);
+    }
+    std::vector<ir::Argument> arguments = {
+        ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
+        ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
+        ir::Argument(kernel_stream_, ir::Argument::IO::kOutput)};
+    ir::Expr host_func =
+        ir::_LoweredFunc_::Make(op->functions[0].as_lowered_func()->name,
+                                arguments,
+                                ir::Block::Make(buckets_),
+                                {});
+    host_module_builder.AddFunctionWithoutOptim(
+        host_func.as_lowered_func_ref());
+  }
+
+  void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
+  Expr CreateDeviceFunction(ir::Expr expr, ir::Expr predicate);
+  inline std::string GenDeviceKernelName(const std::string& fn_name,
+                                         ir::Expr predicate);
+
+ private:
+  std::vector<ir::Expr> buckets_;
+
+  ir::Var kernel_args_;
+  ir::Var kernel_args_num_;
+  ir::Var kernel_stream_;
+};
+
 }  // namespace detail
 }  // namespace backends
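`CollectBucketStrategyHostFunctionVisitor` builds one host function per module whose body is a block of predicate-guarded kernel launches, one per (function, predicate) pair. A stand-in sketch of the dispatch structure the generated host function follows, with `std::function` playing the role of the generated IR:

```cpp
#include <functional>
#include <vector>

// Illustrative stand-in for one (predicate, launch) bucket.
struct Bucket {
  std::function<bool(int)> predicate;  // condition on a dynamic dim, e.g. S0
  std::function<void()> launch;        // the call_cuda_kernel(...) for it
};

// The generated host body is a plain sequence of if-statements, mirroring
// ir::IfThenElse::Make(predicate, call_extern_api) appended per bucket.
void HostDispatch(const std::vector<Bucket>& buckets, int S0) {
  for (const auto& b : buckets) {
    if (b.predicate(S0)) {
      b.launch();
    }
  }
}

int main() {
  std::vector<Bucket> buckets = {
      {[](int s) { return s < 1024; }, [] { /* launch small kernel */ }},
      {[](int s) { return s >= 1024; }, [] { /* launch large kernel */ }},
  };
  HostDispatch(buckets, 512);  // takes the first bucket
}
```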
paddle/cinn/backends/compiler.cc

@@ -18,7 +18,9 @@
 #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
 #include "paddle/cinn/common/context.h"
+#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
 #include "paddle/cinn/hlir/framework/visualize_helper.h"
+#include "paddle/cinn/ir/ir_printer.h"
 #ifdef CINN_WITH_CUDA
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"

@@ -29,27 +31,83 @@
 #include "paddle/cinn/runtime/flags.h"
 #endif
 
-DECLARE_string(cinn_source_code_save_path);
-DECLARE_string(cinn_dump_group_lowered_func);
-DECLARE_string(cinn_dump_group_source_code);
-DECLARE_string(cinn_dump_group_ptx);
-DECLARE_string(cinn_dump_group_instruction);
+PD_DECLARE_string(cinn_source_code_save_path);
+PD_DECLARE_string(cinn_dump_group_lowered_func);
+PD_DECLARE_string(cinn_dump_group_source_code);
+PD_DECLARE_string(cinn_dump_group_ptx);
+PD_DECLARE_string(cinn_dump_group_instruction);
 
 namespace cinn {
 namespace backends {
 using ir::Module;
+using CompilationStatus = hlir::framework::CompilationStatus;
 
 static constexpr int DebugLogMaxLen = 30000;
 
+void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
+    const ir::LoweredFunc& lowered_func, const int gidx, const int device_id) {
+  if (FLAGS_cinn_dump_group_lowered_func.empty() ||
+      lowered_func.get() == nullptr) {
+    return;
+  }
+  std::stringstream content;
+  content << lowered_func;
+  Dump(FLAGS_cinn_dump_group_lowered_func,
+       gidx,
+       device_id,
+       "lowered_function.txt",
+       content.str());
+}
+
+void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
+    const std::string& source_code, const int gidx, const int device_id) {
+  if (FLAGS_cinn_dump_group_source_code.empty()) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_source_code,
+       gidx,
+       device_id,
+       "source_code.cu",
+       source_code);
+}
+
+void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
+    const std::string& source_ptx, const int gidx, const int device_id) {
+  if (FLAGS_cinn_dump_group_ptx.empty()) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_ptx,
+       gidx,
+       device_id,
+       "source_ptx.ptx",
+       source_ptx);
+}
+
+void CompilationInfoDumper::DumpInstructionByGroupIndex(
+    const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
+    const int gidx,
+    const int device_id) {
+  if (FLAGS_cinn_dump_group_instruction.empty() || instr.get() == nullptr) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_instruction,
+       gidx,
+       device_id,
+       "instruction.txt",
+       instr->DumpInstruction());
+}
+
 void CompilationInfoDumper::DumpLoweredFunc() {
   if (FLAGS_cinn_dump_group_lowered_func.empty()) {
     return;
   }
-  for (int idx = 0; idx < info_.lowered_funcs.size(); ++idx) {
+  for (int idx = 0; idx < info_.Size(); ++idx) {
     std::stringstream content;
-    content << info_.lowered_funcs[idx].front();
+    if (info_.Status(idx) > CompilationStatus::LOWERING_FAIL) {
+      content << info_.LoweredFuncs(idx).front();
+    } else {
+      content << "[No lowered func generated]\n\n" << info_.Message(idx);
+    }
     Dump(FLAGS_cinn_dump_group_lowered_func,
          idx,
          device_id_,
          "lowered_function.txt",
          content.str());
   }

@@ -59,11 +117,18 @@ void CompilationInfoDumper::DumpSourceCode() {
   if (FLAGS_cinn_dump_group_source_code.empty()) {
     return;
   }
-  for (int idx = 0; idx < info_.source_codes.size(); ++idx) {
+  for (int idx = 0; idx < info_.Size(); ++idx) {
+    std::string dump_str;
+    if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
+      dump_str = info_.SourceCode(idx);
+    } else {
+      dump_str = "[No source code generated]\n\n" + info_.Message(idx);
+    }
     Dump(FLAGS_cinn_dump_group_source_code,
          idx,
          device_id_,
          "source_code.cu",
-         info_.source_codes[idx]);
+         dump_str);
   }
 }

@@ -71,11 +136,15 @@ void CompilationInfoDumper::DumpPtxCode() {
   if (FLAGS_cinn_dump_group_ptx.empty()) {
     return;
   }
-  for (int idx = 0; idx < info_.source_ptxs.size(); ++idx) {
-    Dump(FLAGS_cinn_dump_group_ptx, idx, "source_ptx.ptx",
-         info_.source_ptxs[idx]);
+  for (int idx = 0; idx < info_.Size(); ++idx) {
+    std::string dump_str;
+    if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
+      dump_str = info_.SourcePtx(idx);
+    } else {
+      dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
+    }
+    Dump(FLAGS_cinn_dump_group_ptx, idx, device_id_, "source_ptx.ptx", dump_str);
   }
 }

@@ -83,20 +152,28 @@ void CompilationInfoDumper::DumpInstruction() {
   if (FLAGS_cinn_dump_group_instruction.empty()) {
     return;
   }
-  for (int idx = 0; idx < info_.instructions.size(); ++idx) {
+  for (int idx = 0; idx < info_.RuntimeInstructions().size(); ++idx) {
+    std::string dump_str;
+    if (info_.RuntimeInstruction(idx).get() != nullptr) {
+      dump_str = info_.RuntimeInstruction(idx)->DumpInstruction();
+    } else {
+      dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
+    }
     Dump(FLAGS_cinn_dump_group_instruction,
          idx,
          device_id_,
          "instruction.txt",
-         info_.instructions[idx]->DumpInstruction());
+         dump_str);
   }
 }
 
 void CompilationInfoDumper::Dump(const std::string& base_path,
                                  const int idx,
+                                 const int device_id,
                                  const std::string& file_name,
                                  const std::string& content) {
-  auto dump_path =
-      utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
+  auto dump_path = utils::StringFormat(
+      "%s/device_%d/fusion_group_%d", base_path.c_str(), device_id, idx);
   if (!hlir::framework::MakeDirectory(
           dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
     LOG(WARNING) << "Failed to make directory: \"" << dump_path

@@ -227,6 +304,8 @@ void Compiler::CompileCudaModule(const Module& module,
     auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
     CHECK(fn_kernel);
 
+    fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
+
     symbols.RegisterVar(kernel_fn_name + "_ptr_",
                         reinterpret_cast<void*>(fn_kernel));
   }
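With the extra `device_id` parameter, `CompilationInfoDumper::Dump` now nests dump artifacts per device. A small standalone sketch of the resulting path layout:

```cpp
#include <cstdio>
#include <string>

// Mirrors the format string changed above:
// <base_path>/device_<id>/fusion_group_<idx>/<file_name>
std::string DumpPath(const std::string& base_path, int device_id, int idx) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s/device_%d/fusion_group_%d",
                base_path.c_str(), device_id, idx);
  return std::string(buf);
}

int main() {
  // prints: ./cinn_dump/device_0/fusion_group_3
  std::printf("%s\n", DumpPath("./cinn_dump", 0, 3).c_str());
}
```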