Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
565
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1133 additions
and
920 deletions
+1133
-920
paddle/cinn/ir/test/CMakeLists.txt
paddle/cinn/ir/test/CMakeLists.txt
+6
-0
paddle/cinn/ir/test/collect_ir_nodes_test.cc
paddle/cinn/ir/test/collect_ir_nodes_test.cc
+5
-4
paddle/cinn/ir/test/ir_compare_test.cc
paddle/cinn/ir/test/ir_compare_test.cc
+7
-11
paddle/cinn/ir/test/ir_copy_test.cc
paddle/cinn/ir/test/ir_copy_test.cc
+5
-4
paddle/cinn/ir/test/ir_printer_test.cc
paddle/cinn/ir/test/ir_printer_test.cc
+1
-1
paddle/cinn/ir/test/ir_verify_test.cc
paddle/cinn/ir/test/ir_verify_test.cc
+6
-4
paddle/cinn/ir/test/schedule_block_graph_test.cc
paddle/cinn/ir/test/schedule_block_graph_test.cc
+69
-16
paddle/cinn/ir/test/schedule_desc_test.cc
paddle/cinn/ir/test/schedule_desc_test.cc
+13
-8
paddle/cinn/ir/test/st_shape_group_scheduler_test.cc
paddle/cinn/ir/test/st_shape_group_scheduler_test.cc
+767
-0
paddle/cinn/ir/test/tensor_test.cc
paddle/cinn/ir/test/tensor_test.cc
+1
-1
paddle/cinn/ir/utils/CMakeLists.txt
paddle/cinn/ir/utils/CMakeLists.txt
+2
-4
paddle/cinn/ir/utils/ir_compare.cc
paddle/cinn/ir/utils/ir_compare.cc
+73
-30
paddle/cinn/ir/utils/ir_compare.h
paddle/cinn/ir/utils/ir_compare.h
+14
-6
paddle/cinn/ir/utils/ir_copy.cc
paddle/cinn/ir/utils/ir_copy.cc
+18
-8
paddle/cinn/ir/utils/ir_copy.h
paddle/cinn/ir/utils/ir_copy.h
+3
-3
paddle/cinn/ir/utils/ir_nodes_collector.cc
paddle/cinn/ir/utils/ir_nodes_collector.cc
+114
-3
paddle/cinn/ir/utils/ir_nodes_collector.h
paddle/cinn/ir/utils/ir_nodes_collector.h
+20
-1
paddle/cinn/ir/utils/ir_printer.cc
paddle/cinn/ir/utils/ir_printer.cc
+0
-711
paddle/cinn/ir/utils/ir_printer.h
paddle/cinn/ir/utils/ir_printer.h
+0
-98
paddle/cinn/ir/utils/ir_replace.cc
paddle/cinn/ir/utils/ir_replace.cc
+9
-7
No files found.
Too many changes to show.
To preserve performance only
565 of 565+
files are displayed.
Plain diff
Email patch
paddle/cinn/ir/test/CMakeLists.txt
View file @
01a10755
...
...
@@ -19,3 +19,9 @@ cinn_cc_test(test_ir_compare SRCS ir_compare_test.cc DEPS cinncore)
cinn_cc_test
(
test_ir_copy SRCS ir_copy_test.cc DEPS cinncore
)
cinn_cc_test
(
test_schedule_block_graph SRCS schedule_block_graph_test.cc DEPS
cinncore
)
if
(
WITH_CUDA
)
cinn_cc_test
(
test_static_shape_group_scheduler SRCS st_shape_group_scheduler_test.cc
DEPS cinncore decomposer_test_helper
)
endif
()
paddle/cinn/ir/test/collect_ir_nodes_test.cc
View file @
01a10755
...
...
@@ -19,6 +19,7 @@
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
TEST
(
CollectIRNodes
,
basic0
)
{
Expr
C
=
Expr
(
1
)
+
2
;
...
...
@@ -41,15 +42,15 @@ TEST(CollectIRNodes, basic) {
auto
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
+
B
(
i
,
j
);
},
"C"
);
a
uto
st
a
ge
s
=
CreateStages
({
C
});
ast
_
ge
n_ius
::
TensorGroup
tensor_group
({
C
});
auto
fn
=
Lower
(
"fn"
,
stages
,
{
A
,
B
,
C
});
auto
fn
=
Lower
ToAst
(
"fn"
,
{
A
,
B
,
C
}
,
&
tensor_group
);
LOG
(
INFO
)
<<
"fn:
\n
"
<<
fn
;
auto
tensors
=
CollectIRNodes
(
fn
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
ASSERT_EQ
(
tensors
.
size
(),
5
UL
);
ASSERT_EQ
(
tensors
.
size
(),
3
UL
);
auto
fn_body
=
fn
.
As
<
ir
::
_LoweredFunc_
>
()
->
body
;
LOG
(
INFO
)
<<
"fn.body:
\n
"
<<
fn_body
;
...
...
@@ -57,6 +58,6 @@ TEST(CollectIRNodes, basic) {
CollectIRNodes
(
fn_body
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
auto
exprs
=
CollectIRNodes
(
fn_body
,
[](
const
Expr
*
x
)
{
return
x
;
});
}
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/test/ir_compare_test.cc
View file @
01a10755
...
...
@@ -23,7 +23,7 @@
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
TEST
(
TestIrCompare
,
SingleFunction
)
{
Target
target
=
common
::
DefaultHostTarget
();
...
...
@@ -128,20 +128,16 @@ TEST(TestIrCompare, SingleFunction) {
ASSERT_EQ
(
func2_str
,
utils
::
GetStreamCnt
(
funcs_2
.
front
()));
ASSERT_EQ
(
func3_str
,
utils
::
GetStreamCnt
(
funcs_3
.
front
()));
IrEqualVisitor
compartor
;
// they are different at the name of root ScheduleBlock
ASSERT_TRUE
(
compartor
.
Compare
(
funcs_1
.
front
(),
funcs_2
.
front
()));
ASSERT_TRUE
(
IR
Compare
(
funcs_1
.
front
(),
funcs_2
.
front
()));
// compare with itself
ASSERT_TRUE
(
compartor
.
Compare
(
funcs_1
.
front
(),
funcs_1
.
front
()));
IrEqualVisitor
compartor_allow_suffix_diff
(
true
);
ASSERT_TRUE
(
IRCompare
(
funcs_1
.
front
(),
funcs_1
.
front
()));
// they are euqal if allowing suffix of name different
ASSERT_TRUE
(
compartor_allow_suffix_diff
.
Compare
(
funcs_1
.
front
(),
funcs_2
.
front
()));
ASSERT_TRUE
(
IRCompare
(
funcs_1
.
front
(),
funcs_2
.
front
(),
true
));
ASSERT_FALSE
(
compartor
.
Compare
(
funcs_1
.
front
(),
funcs_3
.
front
()));
ASSERT_FALSE
(
compartor_allow_suffix_diff
.
Compare
(
funcs_1
.
front
(),
funcs_3
.
front
()));
ASSERT_FALSE
(
IRCompare
(
funcs_1
.
front
(),
funcs_3
.
front
()));
ASSERT_FALSE
(
IRCompare
(
funcs_1
.
front
(),
funcs_3
.
front
(),
true
));
}
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/test/ir_copy_test.cc
View file @
01a10755
...
...
@@ -16,16 +16,17 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
optim
{
namespace
ir
{
namespace
ir_utils
{
TEST
(
IrCopy
,
basic
)
{
Expr
a
(
1.
f
);
auto
aa
=
IRCopy
(
a
);
LOG
(
INFO
)
<<
"aa "
<<
aa
;
}
}
// namespace
optim
}
// namespace ir_utils
}
// namespace
ir
}
// namespace cinn
paddle/cinn/ir/test/ir_printer_test.cc
View file @
01a10755
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include <gtest/gtest.h>
...
...
paddle/cinn/ir/test/ir_verify_test.cc
View file @
01a10755
...
...
@@ -18,12 +18,14 @@
#include "paddle/cinn/ir/op/ir_operators.h"
namespace
cinn
::
ir
{
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
TEST
(
IrVerify
,
basic
)
{
Expr
a
(
1
);
Expr
b
(
1
);
IrVerify
(
a
+
b
);
}
}
// namespace cinn::ir
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/test/schedule_block_graph_test.cc
View file @
01a10755
...
...
@@ -20,6 +20,8 @@
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
PD_DECLARE_bool
(
cinn_new_group_scheduler
);
namespace
cinn
{
namespace
ir
{
...
...
@@ -38,7 +40,8 @@ IRSchedule MakeIRSchedule(frontend::Program* program) {
"inferdtype"
);
auto
&
shape_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
hlir
::
framework
::
OpLowerer
op_lowerer
(
dtype_dict
,
shape_dict
,
target
);
auto
op_lowerer
=
hlir
::
framework
::
CreateOpLowerer
(
dtype_dict
,
shape_dict
,
target
);
std
::
vector
<
LoweredFunc
>
lowered_funcs
=
op_lowerer
.
Lower
(
graph
->
fusion_groups
.
front
(),
false
,
false
);
...
...
@@ -94,10 +97,11 @@ frontend::Program CreateReduceProgram() {
}
TEST
(
ScheduleBlockGraph
,
elementwise
)
{
Context
::
Global
().
ResetNameId
();
frontend
::
Program
program
=
CreateElementwiseProgram
();
IRSchedule
ir_sch
=
MakeIRSchedule
(
&
program
);
ScheduleBlockGraph
sbg
(
ir_sch
);
LOG
(
INFO
)
<<
GetIR
(
ir_sch
);
ScheduleBlockGraph
sbg
(
ir_sch
);
LOG
(
INFO
)
<<
sbg
.
Visualize
();
CHECK_EQ
(
sbg
.
BlockIdsInOrder
().
size
(),
6
);
CHECK_EQ
(
sbg
.
nodes
().
size
(),
6
);
...
...
@@ -135,24 +139,73 @@ TEST(ScheduleBlockGraph, elementwise) {
#ifdef CINN_WITH_CUDA
TEST
(
ScheduleBlockGraph
,
reduce
)
{
if
(
FLAGS_cinn_new_group_scheduler
)
{
Context
::
Global
().
ResetNameId
();
frontend
::
Program
program
=
CreateReduceProgram
();
IRSchedule
ir_sch
=
MakeIRSchedule
(
&
program
);
ScheduleBlockGraph
sbg
(
ir_sch
);
LOG
(
INFO
)
<<
GetIR
(
ir_sch
);
LOG
(
INFO
)
<<
sbg
.
Visualize
();
CHECK_EQ
(
sbg
.
BlockIdsInOrder
().
size
(),
8
);
CHECK_EQ
(
sbg
.
nodes
().
size
(),
8
);
CHECK_EQ
(
sbg
.
BlockIdsInOrder
().
size
(),
5
);
CHECK_EQ
(
sbg
.
nodes
().
size
(),
5
);
ScheduleBlockNode
*
v_reduce_init
=
sbg
.
RetrieveNode
(
"var_
48
__reduce_init"
);
ScheduleBlockNode
*
v_reduce_init
=
sbg
.
RetrieveNode
(
"var_
2
__reduce_init"
);
CHECK
(
v_reduce_init
);
CHECK_EQ
(
v_reduce_init
->
UpstreamNodes
().
size
(),
0
);
CHECK_EQ
(
v_reduce_init
->
DownstreamNodes
().
size
(),
3
);
ScheduleBlockNode
*
v
=
sbg
.
RetrieveNode
(
"var_
48
"
);
ScheduleBlockNode
*
v
=
sbg
.
RetrieveNode
(
"var_
2
"
);
CHECK
(
v
);
CHECK_EQ
(
v
->
UpstreamNodes
().
size
(),
5
);
CHECK_EQ
(
v
->
UpstreamNodes
().
size
(),
2
);
CHECK_EQ
(
v
->
DownstreamNodes
().
size
(),
2
);
std
::
vector
<
std
::
string
>
reverse_dfs_topo_order_ids
;
sbg
.
DFSTopoWalk
(
[
&
reverse_dfs_topo_order_ids
](
const
ScheduleBlockNode
*
node
)
{
reverse_dfs_topo_order_ids
.
push_back
(
node
->
id
());
});
for
(
const
std
::
string
&
id
:
reverse_dfs_topo_order_ids
)
{
LOG
(
INFO
)
<<
id
;
}
CHECK_EQ
(
reverse_dfs_topo_order_ids
.
size
(),
5
);
std
::
vector
<
std
::
string
>
dfs_topo_order_ids
;
sbg
.
DFSTopoWalk
(
[
&
dfs_topo_order_ids
](
const
ScheduleBlockNode
*
node
)
{
dfs_topo_order_ids
.
push_back
(
node
->
id
());
},
false
);
for
(
const
std
::
string
&
id
:
dfs_topo_order_ids
)
{
LOG
(
INFO
)
<<
id
;
}
CHECK_EQ
(
dfs_topo_order_ids
.
size
(),
5
);
}
}
TEST
(
ScheduleBlockGraph
,
arg_max
)
{
Context
::
Global
().
ResetNameId
();
frontend
::
NetBuilder
builder
(
"net_builder"
);
auto
x
=
builder
.
CreateInput
(
Float
(
32
),
{
8
,
16
},
"X"
);
auto
y
=
builder
.
Argmax
(
x
,
0
);
frontend
::
Program
program
=
builder
.
Build
();
IRSchedule
ir_sch
=
MakeIRSchedule
(
&
program
);
LOG
(
INFO
)
<<
GetIR
(
ir_sch
);
ScheduleBlockGraph
sbg
(
ir_sch
);
LOG
(
INFO
)
<<
sbg
.
Visualize
();
CHECK_EQ
(
sbg
.
BlockIdsInOrder
().
size
(),
3
);
CHECK_EQ
(
sbg
.
nodes
().
size
(),
3
);
ScheduleBlockNode
*
v0_idx
=
sbg
.
RetrieveNode
(
"var_0_index"
);
CHECK
(
v0_idx
);
CHECK_EQ
(
v0_idx
->
UpstreamNodes
().
size
(),
1
);
CHECK_EQ
(
v0_idx
->
DownstreamNodes
().
size
(),
1
);
ScheduleBlockNode
*
v0
=
sbg
.
RetrieveNode
(
"var_0"
);
CHECK
(
v0
);
CHECK_EQ
(
v0
->
UpstreamNodes
().
size
(),
2
);
CHECK_EQ
(
v0
->
DownstreamNodes
().
size
(),
0
);
std
::
vector
<
std
::
string
>
reverse_dfs_topo_order_ids
;
sbg
.
DFSTopoWalk
([
&
reverse_dfs_topo_order_ids
](
const
ScheduleBlockNode
*
node
)
{
reverse_dfs_topo_order_ids
.
push_back
(
node
->
id
());
...
...
@@ -160,7 +213,7 @@ TEST(ScheduleBlockGraph, reduce) {
for
(
const
std
::
string
&
id
:
reverse_dfs_topo_order_ids
)
{
LOG
(
INFO
)
<<
id
;
}
CHECK_EQ
(
reverse_dfs_topo_order_ids
.
size
(),
8
);
CHECK_EQ
(
reverse_dfs_topo_order_ids
.
size
(),
3
);
std
::
vector
<
std
::
string
>
dfs_topo_order_ids
;
sbg
.
DFSTopoWalk
(
...
...
@@ -171,7 +224,7 @@ TEST(ScheduleBlockGraph, reduce) {
for
(
const
std
::
string
&
id
:
dfs_topo_order_ids
)
{
LOG
(
INFO
)
<<
id
;
}
CHECK_EQ
(
dfs_topo_order_ids
.
size
(),
8
);
CHECK_EQ
(
dfs_topo_order_ids
.
size
(),
3
);
}
#endif
...
...
paddle/cinn/ir/test/schedule_desc_test.cc
View file @
01a10755
...
...
@@ -19,9 +19,9 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/type_defs.h"
...
...
@@ -95,7 +95,7 @@ std::vector<ir::LoweredFunc> LowerCompute(
IRSchedule
MakeIRSchedule
(
const
std
::
vector
<
ir
::
LoweredFunc
>&
lowered_funcs
)
{
std
::
vector
<
Expr
>
exprs
;
for
(
auto
&&
func
:
lowered_funcs
)
{
exprs
.
emplace_back
(
optim
::
IRCopy
(
func
->
body
));
exprs
.
emplace_back
(
ir
::
ir_utils
::
IRCopy
(
func
->
body
));
}
return
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
exprs
));
}
...
...
@@ -106,10 +106,11 @@ std::string SourceCodeGen(const ModuleExpr& module_expr,
const
Target
&
target
)
{
auto
exprs
=
module_expr
.
GetExprs
();
CHECK_EQ
(
exprs
.
size
(),
lowered_funcs
.
size
())
<<
"size of func is not euqal"
;
std
::
vector
<
ir
::
LoweredFunc
>
updated_funcs
=
optim
::
IRCopy
(
lowered_funcs
);
std
::
vector
<
ir
::
LoweredFunc
>
updated_funcs
=
ir
::
ir_utils
::
IRCopy
(
lowered_funcs
);
Module
::
Builder
builder
(
"test_module"
,
target
);
for
(
auto
i
=
0
;
i
<
lowered_funcs
.
size
();
++
i
)
{
updated_funcs
[
i
]
->
body
=
optim
::
IRCopy
(
exprs
.
at
(
i
));
updated_funcs
[
i
]
->
body
=
ir
::
ir_utils
::
IRCopy
(
exprs
.
at
(
i
));
builder
.
AddFunction
(
updated_funcs
[
i
]);
}
auto
module
=
builder
.
Build
();
...
...
@@ -778,6 +779,7 @@ TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
CheckReplayResult
(
ir_sch
,
ir_sch
.
GetTraceDesc
());
}
#ifdef CINN_WITH_CUDA
TEST_F
(
TestScheduleDesc
,
StepKind_Bind
)
{
lowered_funcs
=
LowerCompute
({
32
,
128
},
target
);
ir
::
IRSchedule
ir_sch
=
MakeIRSchedule
(
lowered_funcs
);
...
...
@@ -793,6 +795,7 @@ TEST_F(TestScheduleDesc, StepKind_Bind) {
CheckReplayResult
(
ir_sch
,
trace
);
CheckReplayResult
(
ir_sch
,
ir_sch
.
GetTraceDesc
());
}
#endif
TEST_F
(
TestScheduleDesc
,
StepKind_Rfactor
)
{
Expr
M
(
32
);
...
...
@@ -839,12 +842,14 @@ TEST_F(TestScheduleDesc, StepKind_MergeExprs) {
auto
funcs_1
=
LowerCompute
({
32
,
32
,
32
},
target
,
true
,
"elementwise-add_const"
);
ir
::
IRSchedule
ir_sch
=
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
{
optim
::
IRCopy
(
funcs_0
[
0
]
->
body
),
optim
::
IRCopy
(
funcs_0
[
0
]
->
body
)}));
ir
::
IRSchedule
ir_sch
=
ir
::
IRSchedule
(
ir
::
ModuleExpr
({
ir
::
ir_utils
::
IRCopy
(
funcs_0
[
0
]
->
body
),
ir
::
ir_utils
::
IRCopy
(
funcs_0
[
0
]
->
body
)}));
ir_sch
.
MergeExprs
();
trace
.
Append
(
ScheduleDesc
::
Step
(
"MergeExprs"
,
{},
{},
{}));
ir
::
IRSchedule
replay_sch
=
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
{
optim
::
IRCopy
(
funcs_0
[
0
]
->
body
),
optim
::
IRCopy
(
funcs_0
[
0
]
->
body
)}));
ir
::
IRSchedule
replay_sch
=
ir
::
IRSchedule
(
ir
::
ModuleExpr
({
ir
::
ir_utils
::
IRCopy
(
funcs_0
[
0
]
->
body
),
ir
::
ir_utils
::
IRCopy
(
funcs_0
[
0
]
->
body
)}));
trace
.
Replay
(
&
replay_sch
);
auto
lhs_exprs
=
ir_sch
.
GetModule
().
GetExprs
();
...
...
paddle/cinn/ir/test/st_shape_group_scheduler_test.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/decomposer/test_helper.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
PD_DECLARE_bool
(
cinn_new_group_scheduler
);
namespace
cinn
{
namespace
ir
{
using
frontend
::
NetBuilder
;
using
frontend
::
RunDecomposer
;
void
Compile
(
NetBuilder
*
net_builder
)
{
auto
program
=
net_builder
->
Build
();
auto
target
=
common
::
DefaultTarget
();
RunDecomposer
(
&
program
,
target
);
auto
graph
=
std
::
make_shared
<
hlir
::
framework
::
Graph
>
(
program
,
target
);
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"FusionMergePass"
);
CHECK_EQ
(
graph
->
fusion_groups
.
size
(),
1
);
auto
&
dtype_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
Type
>>
(
"inferdtype"
);
auto
&
shape_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
auto
op_lowerer
=
hlir
::
framework
::
CreateOpLowerer
(
dtype_dict
,
shape_dict
,
target
);
for
(
auto
&
fusion_group
:
graph
->
fusion_groups
)
{
std
::
vector
<
ir
::
LoweredFunc
>
lowered_funcs
=
op_lowerer
.
Lower
(
fusion_group
,
/* apply_op_schedule = */
true
,
/* apply_group_schedule = */
false
);
CHECK_EQ
(
lowered_funcs
.
size
(),
1
);
VLOG
(
1
)
<<
"without group schedule, lowered_func: "
<<
lowered_funcs
.
front
();
FLAGS_cinn_new_group_scheduler
=
true
;
lowered_funcs
=
op_lowerer
.
Lower
(
fusion_group
,
/* apply_op_schedule = */
true
,
/* apply_group_schedule = */
true
);
CHECK_EQ
(
lowered_funcs
.
size
(),
1
);
VLOG
(
1
)
<<
"after group schedule, lowered_func: "
<<
lowered_funcs
.
front
();
}
}
void
CheckAccuracy
(
NetBuilder
*
net_builder
,
const
std
::
vector
<
std
::
string
>&
input_names
)
{
FLAGS_cinn_new_group_scheduler
=
true
;
auto
program
=
net_builder
->
Build
();
auto
target
=
common
::
DefaultTarget
();
auto
graph
=
std
::
make_shared
<
hlir
::
framework
::
Graph
>
(
program
,
target
);
hlir
::
framework
::
ApplyPasses
(
graph
.
get
(),
{
"OpFusionPass"
,
"FusionMergePass"
});
VLOG
(
1
)
<<
"Before CheckFusionAccuracyPass:
\n
"
<<
graph
->
DebugGroupedGraph
(
std
::
unordered_set
<
std
::
string
>
{});
hlir
::
framework
::
ApplyPasses
(
graph
.
get
(),
{
"CheckFusionAccuracyPass"
,
"TransToCustomCallPass"
});
VLOG
(
1
)
<<
"After CheckFusionAccuracyPass:
\n
"
<<
graph
->
DebugGroupedGraph
(
std
::
unordered_set
<
std
::
string
>
{});
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
input_names
[
i
]);
auto
tensor
=
scope
->
GetTensor
(
input_names
[
i
]);
std
::
vector
<
float
>
vec
;
frontend
::
InitRandomVector
<
float
>
(
&
vec
,
tensor
->
shape
().
numel
(),
0.0
f
,
1.0
f
);
frontend
::
CopyFromVector
<
float
>
(
vec
,
tensor
,
target
);
}
auto
runtime_program
=
gc
.
Build
();
runtime_program
->
Execute
();
}
// Each unittest below tests a single reduce,
// these unittests are only used to observe the generated IR and debug.
// Accuracy testing is guaranteed by Python unittests named
// test_reduce_op_xxx.py.
TEST
(
GROUP_SCHEDULER
,
last_reduce_only_1
)
{
NetBuilder
net_builder
(
"last_reduce_only_1"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
,
64
,
32
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
2
});
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
last_reduce_only_2
)
{
NetBuilder
net_builder
(
"last_reduce_only_2"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
1024
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
0
});
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
last_reduce_only_3
)
{
NetBuilder
net_builder
(
"last_reduce_only_3"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
512
,
256
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
1
});
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
non_last_reduce_only_1
)
{
NetBuilder
net_builder
(
"non_last_reduce_only_1"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
10
,
10
,
10
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
0
,
1
},
/* keep_dim = */
true
);
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
non_last_reduce_only_2
)
{
NetBuilder
net_builder
(
"non_last_reduce_only_2"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
64
,
32
,
16
,
8
,
4
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
1
,
2
,
3
},
/* keep_dim = */
true
);
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
shuffle_reduce_only_1
)
{
NetBuilder
net_builder
(
"shuffle_reduce_only_1"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
32
,
32
,
32
,
32
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
0
,
2
,
3
});
};
CreateModel
();
Compile
(
&
net_builder
);
}
TEST
(
GROUP_SCHEDULER
,
shuffle_reduce_only_2
)
{
NetBuilder
net_builder
(
"shuffle_reduce_only_2"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
32
,
64
,
56
,
56
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
0
,
2
,
3
});
};
CreateModel
();
Compile
(
&
net_builder
);
}
// Each of the following unittest tests a basic pattern composed of multiple
// basic op. And apply accuracy checks to ensure that the results of fusion
// groups and independently running each op are consistent.
TEST
(
GROUP_SCHEDULER
,
elementwise_1
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_1"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_2
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_2"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
Cast
(
C
,
"float16"
);
auto
E
=
net_builder
.
Cast
(
C
,
"float16"
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_3
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_3"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
Cast
(
C
,
"float16"
);
auto
E
=
net_builder
.
Cast
(
C
,
"float16"
);
auto
F
=
net_builder
.
Cast
(
D
,
"float32"
);
auto
G
=
net_builder
.
Cast
(
E
,
"float32"
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_4
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_4"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
Cast
(
C
,
"float16"
);
auto
E
=
net_builder
.
Cast
(
C
,
"float16"
);
auto
F
=
net_builder
.
Add
(
D
,
E
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_broadcast
)
{
NetBuilder
net_builder
(
"elementwise_broadcast"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
BroadcastTo
(
C
,
{
128
,
128
});
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_double_broadcast
)
{
NetBuilder
net_builder
(
"elementwise_double_broadcast"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
BroadcastTo
(
C
,
{
128
,
128
});
auto
E
=
net_builder
.
BroadcastTo
(
C
,
{
128
,
128
});
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
non_last_reduce_elementwise_1
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"non_last_reduce_elementwise_1"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
0
});
auto
C
=
net_builder
.
Cast
(
B
,
"float16"
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
last_reduce_elementwise
)
{
NetBuilder
net_builder
(
"last_reduce_elementwise"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"C"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
,
64
},
"A"
);
auto
B
=
net_builder
.
ReduceSum
(
A
,
{
1
});
auto
C
=
net_builder
.
CreateInput
(
Float
(
32
),
{
128
},
"C"
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
keep_dim_reduce_elementwise_1
)
{
NetBuilder
net_builder
(
"keep_dim_reduce_elementwise"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"C"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
112
,
112
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
1
,
64
,
1
,
1
},
"B"
);
auto
C
=
net_builder
.
ReduceSum
(
A
,
{
0
,
2
,
3
},
true
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
keep_dim_reduce_elementwise_2
)
{
NetBuilder
net_builder
(
"keep_dim_reduce_elementwise_2"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
112
,
112
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
1
,
1
},
"B"
);
auto
C
=
net_builder
.
ReduceSum
(
A
,
{
2
,
3
},
true
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
keep_dim_reduce_elementwise_3
)
{
NetBuilder
net_builder
(
"keep_dim_reduce_elementwise_3"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
2048
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
1
},
"B"
);
auto
C
=
net_builder
.
ReduceSum
(
A
,
{
2
},
true
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
keep_dim_reduce_elementwise_4
)
{
NetBuilder
net_builder
(
"keep_dim_reduce_elementwise_4"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
2048
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
1
,
2048
},
"B"
);
auto
C
=
net_builder
.
ReduceSum
(
A
,
{
1
},
true
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
keep_dim_reduce_elementwise_5
)
{
NetBuilder
net_builder
(
"keep_dim_reduce_elementwise_5"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
64
,
16
,
1024
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
16
,
1
,
16
,
1
},
"B"
);
auto
C
=
net_builder
.
ReduceSum
(
A
,
{
1
,
3
},
true
);
auto
D
=
net_builder
.
Add
(
B
,
C
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_non_last_reduce
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_non_last_reduce"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
ReduceSum
(
C
,
{
0
});
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_last_reduce
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_last_reduce"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"C"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
D
=
net_builder
.
ReduceSum
(
C
,
{
1
});
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_non_last_reduce_elementwise
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_non_last_reduce_elementwise"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
E
=
net_builder
.
ReduceSum
(
C
,
{
0
});
auto
F
=
net_builder
.
Cast
(
E
,
"float16"
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST
(
GROUP_SCHEDULER
,
elementwise_last_reduce_elementwise
)
{
int
h
=
128
,
w
=
128
;
NetBuilder
net_builder
(
"elementwise_non_last_reduce_elementwise"
);
std
::
vector
<
std
::
string
>
input_names
=
{
"A"
,
"B"
};
// create model
auto
CreateModel
=
[
&
]()
{
auto
A
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"A"
);
auto
B
=
net_builder
.
CreateInput
(
Float
(
32
),
{
h
,
w
},
"B"
);
auto
C
=
net_builder
.
Add
(
A
,
B
);
auto
E
=
net_builder
.
ReduceSum
(
C
,
{
1
});
auto
F
=
net_builder
.
Cast
(
E
,
"float16"
);
};
CreateModel
();
Compile
(
&
net_builder
);
CreateModel
();
CheckAccuracy
(
&
net_builder
,
input_names
);
}
TEST(GROUP_SCHEDULER, elementwise_double_non_last_reduce_elementwise) {
  int h = 128, w = 128;
  NetBuilder net_builder("elementwise_double_non_last_reduce_elementwise");
  std::vector<std::string> input_names = {"A", "B"};
  // Build: two identical non-last-axis reductions over the same
  // elementwise result, recombined by an add.
  auto BuildModel = [&]() {
    auto a = net_builder.CreateInput(Float(32), {h, w}, "A");
    auto b = net_builder.CreateInput(Float(32), {h, w}, "B");
    auto sum = net_builder.Add(a, b);
    auto reduce0 = net_builder.ReduceSum(sum, {0});
    auto reduce1 = net_builder.ReduceSum(sum, {0});
    auto combined = net_builder.Add(reduce0, reduce1);
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, double_non_last_reduce_elementwise) {
  int h = 128, w = 128;
  NetBuilder net_builder("double_non_last_reduce_elementwise");
  std::vector<std::string> input_names = {"A", "B"};
  // Build: reduce two differently-shaped inputs over axis 0 (the
  // reductions both yield shape {w}), then add the results.
  auto BuildModel = [&]() {
    auto a = net_builder.CreateInput(Float(32), {h, w}, "A");
    auto b = net_builder.CreateInput(Float(32), {h * 2, w}, "B");
    auto reduce_a = net_builder.ReduceSum(a, {0});
    auto reduce_b = net_builder.ReduceSum(b, {0});
    auto combined = net_builder.Add(reduce_a, reduce_b);
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, triple_non_last_reduce) {
  int h = 128, w = 1024;
  NetBuilder net_builder("triple_non_last_reduce");
  // Fixed: the model creates only one input ("A"); "B" was listed here
  // by mistake (it matched no created input).
  std::vector<std::string> input_names = {"A"};
  // create model: three independent axis-0 reductions of the same input.
  auto CreateModel = [&]() {
    // Fixed: use h/w instead of repeating the literals 128/1024, so the
    // declared extents actually control the input shape.
    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
    auto B = net_builder.ReduceSum(A, {0});
    auto C = net_builder.ReduceSum(A, {0});
    auto D = net_builder.ReduceSum(A, {0});
  };
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_1) {
  int h = 32, w = 32;
  NetBuilder net_builder("reduce_broadcast_1");
  std::vector<std::string> input_names = {"A"};
  // Build: 1-D full reduction followed by a broadcast back to the
  // original extent.
  auto BuildModel = [&]() {
    auto input = net_builder.CreateInput(Float(32), {h * w}, "A");
    auto reduced = net_builder.ReduceSum(input, {0});
    auto expanded = net_builder.BroadcastTo(reduced, {h * w}, {0});
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_2) {
  int h = 32, w = 32;
  NetBuilder net_builder("reduce_broadcast_2");
  std::vector<std::string> input_names = {"A"};
  // Build: full 2-D reduction to a scalar-like value, then broadcast it
  // back to {h, w} along broadcast axis 1.
  auto BuildModel = [&]() {
    auto input = net_builder.CreateInput(Float(32), {h, w}, "A");
    auto reduced = net_builder.ReduceSum(input, {0, 1});
    auto expanded = net_builder.BroadcastTo(reduced, {h, w}, {1});
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_3) {
  int h = 32, w = 32;
  NetBuilder net_builder("reduce_broadcast_3");
  std::vector<std::string> input_names = {"A"};
  // Build: reduce the trailing two axes of a 3-D input, then broadcast
  // the per-slice result back to the full shape.
  auto BuildModel = [&]() {
    auto input = net_builder.CreateInput(Float(32), {h, h, w}, "A");
    auto reduced = net_builder.ReduceSum(input, {1, 2});
    auto expanded = net_builder.BroadcastTo(reduced, {h, h, w}, {0});
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_reduce_broadcast) {
  int h = 32, w = 32;
  NetBuilder net_builder("reduce_broadcast_reduce_broadcast");
  std::vector<std::string> input_names = {"A"};
  // Build: two chained reduce->broadcast round trips over a 3-D tensor.
  auto BuildModel = [&]() {
    auto input = net_builder.CreateInput(Float(32), {h, h, w}, "A");
    auto reduce1 = net_builder.ReduceSum(input, {1, 2});
    auto expand1 = net_builder.BroadcastTo(reduce1, {h, h, w}, {0});
    auto reduce2 = net_builder.ReduceSum(expand1, {1, 2});
    auto expand2 = net_builder.BroadcastTo(reduce2, {h, h, w}, {0});
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, reduce_broadcast_elementwise) {
  int h = 32, w = 32;
  NetBuilder net_builder("reduce_broadcast_elementwise");
  // Fixed: the model creates two inputs ("A" and "B"); "B" was missing
  // from the accuracy-check input list.
  std::vector<std::string> input_names = {"A", "B"};
  // create model: reduce+broadcast one input, broadcast the other, and
  // combine them elementwise.
  auto CreateModel = [&]() {
    auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A");
    auto B = net_builder.ReduceSum(A, {1, 2});
    auto C = net_builder.BroadcastTo(B, {h, h, w}, {0});
    auto D = net_builder.CreateInput(Float(32), {h, w}, "B");
    auto E = net_builder.BroadcastTo(D, {h, h, w}, {1, 2});
    auto F = net_builder.Add(C, E);
  };
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_1) {
  NetBuilder net_builder("elementwise_double_reduce_elementwise_1");
  // Fixed: the model creates two inputs ("A" and "B"); "B" was missing
  // from the accuracy-check input list.
  std::vector<std::string> input_names = {"A", "B"};
  // create model: elementwise add feeding two keep_dim=false reductions,
  // whose results are added back together.
  auto CreateModel = [&]() {
    auto A = net_builder.CreateInput(Float(32), {32, 32}, "A");
    auto B = net_builder.CreateInput(Float(32), {32, 32}, "B");
    auto C = net_builder.Add(A, B);
    auto D = net_builder.ReduceSum(C, {1}, false);
    auto E = net_builder.ReduceSum(C, {1}, false);
    auto F = net_builder.Add(D, E);
  };
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_2) {
  NetBuilder net_builder("elementwise_double_reduce_elementwise_2");
  // Fixed: the model creates two inputs ("A" and "B"); "B" was missing
  // from the accuracy-check input list.
  std::vector<std::string> input_names = {"A", "B"};
  // create model: same pattern as _1 but with a degenerate leading
  // dimension {1, 1000}.
  auto CreateModel = [&]() {
    auto A = net_builder.CreateInput(Float(32), {1, 1000}, "A");
    auto B = net_builder.CreateInput(Float(32), {1, 1000}, "B");
    auto C = net_builder.Add(A, B);
    auto D = net_builder.ReduceSum(C, {1}, false);
    auto E = net_builder.ReduceSum(C, {1}, false);
    auto F = net_builder.Add(D, E);
  };
  CreateModel();
  Compile(&net_builder);
  CreateModel();
  CheckAccuracy(&net_builder, input_names);
}
// Each of following unittests tests a group composed of typical operators
TEST(GROUP_SCHEDULER, layernorm) {
  int h = 32, w = 1024;
  NetBuilder net_builder("layernorm");
  std::vector<std::string> input_names = {"A"};
  // Layer normalization built from primitive ops:
  //   y = (x - mean(x)) / sqrt(mean(x^2) - mean(x)^2 + eps)
  auto BuildModel = [&]() {
    // x
    auto x = net_builder.CreateInput(Float(32), {h, w}, "A");
    // x * x
    auto x_sq = net_builder.Multiply(x, x);
    // sum(x) along the last axis
    auto sum_x = net_builder.ReduceSum(x, {1});
    // sum(x * x) along the last axis
    auto sum_x2 = net_builder.ReduceSum(x_sq, {1});
    // constant divisor: the reduced extent w (= 1024)
    auto denom = net_builder.FillConstant<float>({h}, 1024.0f, "E");
    // mean(x)
    auto mean = net_builder.Divide(sum_x, denom);
    auto mean_bc = net_builder.BroadcastTo(mean, {h, w}, {0});
    // mean(x * x)
    auto mean_x2 = net_builder.Divide(sum_x2, denom);
    // mean(x)^2
    auto mean_sq = net_builder.Multiply(mean, mean);
    // variance
    auto variance = net_builder.Subtract(mean_x2, mean_sq);
    // epsilon for numerical stability
    auto eps = net_builder.FillConstant<float>({h}, 1e-10f, "J");
    auto var_eps = net_builder.Add(variance, eps);
    // standard deviation
    auto stddev = net_builder.Sqrt(var_eps);
    auto stddev_bc = net_builder.BroadcastTo(stddev, {h, w}, {0});
    // x - mean
    auto centered = net_builder.Subtract(x, mean_bc);
    // normalize
    auto normed = net_builder.Divide(centered, stddev_bc);
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
TEST(GROUP_SCHEDULER, softmax) {
  int h = 32, w = 1024;
  NetBuilder net_builder("softmax");
  std::vector<std::string> input_names = {"A"};
  // Numerically-stable softmax over the last axis:
  //   y = exp(x - max(x)) / sum(exp(x - max(x)))
  auto BuildModel = [&]() {
    auto x = net_builder.CreateInput(Float(32), {h, w}, "A");
    // rowwise max
    auto row_max = net_builder.ReduceMax(x, {1});
    auto row_max_bc = net_builder.BroadcastTo(row_max, {h, w}, {0});
    // x - max(x)
    auto shifted = net_builder.Subtract(x, row_max_bc);
    // exp(x - max(x))
    auto exps = net_builder.Exp(shifted);
    // rowwise sum of exponentials
    auto row_sum = net_builder.ReduceSum(exps, {1});
    auto row_sum_bc = net_builder.BroadcastTo(row_sum, {h, w}, {0});
    // normalize
    auto probs = net_builder.Divide(exps, row_sum_bc);
  };
  BuildModel();
  Compile(&net_builder);
  BuildModel();
  CheckAccuracy(&net_builder, input_names);
}
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/test/tensor_test.cc
View file @
01a10755
...
...
@@ -20,8 +20,8 @@
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
...
...
paddle/cinn/ir/utils/CMakeLists.txt
View file @
01a10755
...
...
@@ -3,10 +3,8 @@ core_gather_headers()
gather_srcs
(
cinnapi_src
SRCS
ir_visitor.cc
ir_mutator.cc
ir_printer.cc
ir_verify.cc
ir_compare.cc
ir_nodes_collector.cc
ir_copy.cc
)
ir_copy.cc
ir_replace.cc
)
paddle/cinn/ir/utils/ir_compare.cc
View file @
01a10755
...
...
@@ -17,16 +17,22 @@
#include <regex>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
bool
IrEqualVisitor
::
Compare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
if
(
lhs
.
get
()
==
rhs
.
get
())
{
// the same object, including both are null
return
true
;
}
if
(
only_compare_structure_
&&
!
lhs
.
defined
()
&&
!
rhs
.
defined
())
{
return
true
;
}
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
{
// someone invalid
return
false
;
VLOG
(
5
)
<<
"Not equal on Expr, someone not defined"
;
...
...
@@ -44,10 +50,9 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return
equal
;
}
bool
IrEqualVisitor
::
Compare
(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
,
bool
allow_name_suffix_diff
)
{
// if allow_name_suffix_diff=true then just compare the name prefix before the
bool
IrEqualVisitor
::
Compare
(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
)
{
// if allow_name_suffix_diff_=true then just compare the name prefix before
// the
// "_[0-9]+"
auto
common_len
=
0
;
for
(;
common_len
<
lhs
.
size
()
&&
common_len
<
rhs
.
size
();
++
common_len
)
{
...
...
@@ -65,7 +70,7 @@ bool IrEqualVisitor::Compare(const std::string& lhs,
equal
=
true
;
}
else
{
equal
=
false
;
if
(
allow_name_suffix_diff
)
{
if
(
allow_name_suffix_diff
_
)
{
equal
=
is_endswith_index
(
lhs
)
&&
is_endswith_index
(
rhs
);
}
}
...
...
@@ -179,17 +184,26 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) {
bool
IrEqualVisitor
::
Visit
(
const
Call
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
Call
>
();
return
lhs
->
name
==
rhs
->
name
&&
Compare
(
lhs
->
read_args
,
rhs
->
read_args
)
&&
bool
flag
=
Compare
(
lhs
->
read_args
,
rhs
->
read_args
)
&&
Compare
(
lhs
->
write_args
,
rhs
->
write_args
)
&&
Compare
(
lhs
->
attrs
,
rhs
->
attrs
)
&&
lhs
->
call_type
==
rhs
->
call_type
;
Compare
(
lhs
->
attrs
,
rhs
->
attrs
)
&&
lhs
->
call_type
==
rhs
->
call_type
;
if
(
only_compare_structure_
)
{
return
flag
;
}
return
lhs
->
name
==
rhs
->
name
&&
flag
;
// TODO(CtfGo): Compare `func` field
}
bool
IrEqualVisitor
::
Visit
(
const
_Var_
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
_Var_
>
();
return
lhs
->
name
==
rhs
->
name
&&
Compare
(
lhs
->
lower_bound
,
rhs
->
lower_bound
)
&&
Compare
(
lhs
->
upper_bound
,
rhs
->
upper_bound
)
&&
lhs
->
tag
==
rhs
->
tag
;
bool
flag
=
Compare
(
lhs
->
lower_bound
,
rhs
->
lower_bound
)
&&
Compare
(
lhs
->
upper_bound
,
rhs
->
upper_bound
)
&&
lhs
->
tag
==
rhs
->
tag
;
if
(
only_compare_structure_
)
{
return
flag
;
}
return
lhs
->
name
==
rhs
->
name
&&
flag
;
}
bool
IrEqualVisitor
::
Visit
(
const
Load
*
lhs
,
const
Expr
*
other
)
{
...
...
@@ -219,19 +233,25 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) {
bool
IrEqualVisitor
::
Visit
(
const
_Buffer_
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
_Buffer_
>
();
return
Compare
(
lhs
->
shape
,
rhs
->
shape
)
&&
Compare
(
lhs
->
strides
,
rhs
->
strides
)
&&
lhs
->
name
==
rhs
->
name
&&
lhs
->
scope
==
rhs
->
scope
&&
Compare
(
lhs
->
elem_offset
,
rhs
->
elem_offset
)
&&
lhs
->
offset_factor
==
rhs
->
offset_factor
&&
lhs
->
target
==
rhs
->
target
&&
bool
flag
=
Compare
(
lhs
->
shape
,
rhs
->
shape
)
&&
Compare
(
lhs
->
strides
,
rhs
->
strides
)
&&
lhs
->
scope
==
rhs
->
scope
&&
Compare
(
lhs
->
elem_offset
,
rhs
->
elem_offset
)
&&
lhs
->
offset_factor
==
rhs
->
offset_factor
&&
lhs
->
target
==
rhs
->
target
&&
lhs
->
data_alignment
==
rhs
->
data_alignment
&&
lhs
->
memory_type
==
rhs
->
memory_type
&&
lhs
->
dtype
==
rhs
->
dtype
;
if
(
only_compare_structure_
)
{
return
flag
;
}
return
flag
&&
lhs
->
name
==
rhs
->
name
;
}
bool
IrEqualVisitor
::
Visit
(
const
_Tensor_
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
_Tensor_
>
();
return
lhs
->
name
==
rhs
->
name
&&
Compare
(
lhs
->
shape
,
rhs
->
shape
);
bool
flag
=
Compare
(
lhs
->
shape
,
rhs
->
shape
);
if
(
only_compare_structure_
)
{
return
flag
;
}
return
flag
&&
Compare
(
lhs
->
name
,
rhs
->
name
);
}
bool
IrEqualVisitor
::
Visit
(
const
_LoweredFunc_
*
lhs
,
const
Expr
*
other
)
{
...
...
@@ -280,10 +300,15 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {
bool
IrEqualVisitor
::
Visit
(
const
_Module_
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
_Module_
>
();
return
lhs
->
name
==
rhs
->
name
&&
lhs
->
target
==
rhs
->
target
&&
Compare
(
lhs
->
buffers
,
rhs
->
buffers
)
&&
bool
flag
=
Compare
(
lhs
->
buffers
,
rhs
->
buffers
)
&&
Compare
(
lhs
->
functions
,
rhs
->
functions
)
&&
Compare
(
lhs
->
submodules
,
rhs
->
submodules
);
if
(
only_compare_structure_
)
{
return
flag
;
}
return
flag
&&
lhs
->
name
==
rhs
->
name
;
}
bool
IrEqualVisitor
::
Visit
(
const
Let
*
lhs
,
const
Expr
*
other
)
{
...
...
@@ -345,11 +370,16 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) {
bool
IrEqualVisitor
::
Visit
(
const
ScheduleBlock
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
ScheduleBlock
>
();
return
Compare
(
lhs
->
name
,
rhs
->
name
,
allow_name_suffix_diff_
)
&&
Compare
(
lhs
->
iter_vars
,
rhs
->
iter_vars
)
&&
bool
flag
=
Compare
(
lhs
->
iter_vars
,
rhs
->
iter_vars
)
&&
Compare
(
lhs
->
read_buffers
,
rhs
->
read_buffers
)
&&
Compare
(
lhs
->
write_buffers
,
rhs
->
write_buffers
)
&&
Compare
(
lhs
->
attrs
,
rhs
->
attrs
)
&&
Compare
(
lhs
->
body
,
rhs
->
body
);
Compare
(
lhs
->
body
,
rhs
->
body
);
if
(
only_compare_structure_
)
{
return
flag
;
}
return
flag
&&
Compare
(
lhs
->
attrs
,
rhs
->
attrs
)
&&
Compare
(
lhs
->
name
,
rhs
->
name
);
}
bool
IrEqualVisitor
::
Visit
(
const
ScheduleBlockRealize
*
lhs
,
const
Expr
*
other
)
{
...
...
@@ -358,5 +388,18 @@ bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
Compare
(
lhs
->
schedule_block
,
rhs
->
schedule_block
);
}
bool
IrEqualVisitor
::
Visit
(
const
_Dim_
*
lhs
,
const
Expr
*
other
)
{
auto
*
rhs
=
other
->
As
<
_Dim_
>
();
return
lhs
->
name
==
rhs
->
name
&&
lhs
->
GetSymbolName
()
==
rhs
->
GetSymbolName
()
&&
lhs
->
GetRealDimSize
()
==
rhs
->
GetRealDimSize
();
}
bool
IRCompare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
,
bool
allow_name_suffix_diff
)
{
IrEqualVisitor
ir_equal_visitor
(
allow_name_suffix_diff
);
return
ir_equal_visitor
.
Compare
(
lhs
,
rhs
);
}
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_compare.h
View file @
01a10755
...
...
@@ -16,24 +16,25 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
// Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor
class
IrEqualVisitor
:
public
IRVisitorRequireReImpl
<
bool
,
const
Expr
*>
{
public:
explicit
IrEqualVisitor
(
bool
allow_name_suffix_diff
=
false
)
:
allow_name_suffix_diff_
(
allow_name_suffix_diff
)
{}
explicit
IrEqualVisitor
(
bool
allow_name_suffix_diff
=
false
,
bool
only_compare_structure
=
false
)
:
allow_name_suffix_diff_
(
allow_name_suffix_diff
),
only_compare_structure_
(
only_compare_structure
)
{}
// Return true if they are euqal, otherwise false;
bool
Compare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
);
private:
bool
Compare
(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
,
bool
allow_name_suffix_diff
=
false
);
bool
Compare
(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
);
bool
Compare
(
const
std
::
map
<
std
::
string
,
attr_t
>&
lhs
,
const
std
::
map
<
std
::
string
,
attr_t
>&
rhs
);
template
<
typename
T
>
...
...
@@ -45,7 +46,14 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
// whether allowing name suffix ends with "_[0-9]+" different
bool
allow_name_suffix_diff_
=
false
;
// not compare name field of Expr
bool
only_compare_structure_
=
false
;
};
bool
IRCompare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
,
bool
allow_name_suffix_diff
=
false
);
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_copy.cc
View file @
01a10755
...
...
@@ -21,15 +21,15 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
optim
{
using
namespace
ir
;
// NOLINT
namespace
ir
{
namespace
ir
_utils
{
namespace
{
struct
IRCopyVisitor
:
public
ir
::
IRVisitorRequireReImpl
<
Expr
>
{
// Use maps to unify all the copied tensors and buffers.
std
::
map
<
std
::
string
,
ir
::
_Tensor_
*>
tensor_map
;
...
...
@@ -241,6 +241,7 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
std
::
vector
<
Expr
>
buffers
;
std
::
vector
<
Expr
>
functions
;
std
::
vector
<
Expr
>
submodules
;
std
::
vector
<
Expr
>
predicates
;
for
(
auto
&
expr
:
op
->
buffers
)
{
buffers
.
push_back
(
Visit
(
&
expr
));
...
...
@@ -254,10 +255,15 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
submodules
.
push_back
(
Visit
(
&
expr
));
}
for
(
auto
&
expr
:
op
->
predicates
)
{
predicates
.
push_back
(
Visit
(
&
expr
));
}
auto
res
=
ir
::
_Module_
::
Make
(
op
->
name
,
op
->
target
);
res
->
buffers
=
buffers
;
res
->
functions
=
functions
;
res
->
submodules
=
submodules
;
res
->
predicates
=
predicates
;
return
Expr
(
res
);
}
...
...
@@ -407,6 +413,10 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
Visit
(
&
op
->
schedule_block
));
}
Expr
Visit
(
const
ir
::
_Dim_
*
op
)
override
{
return
ir
::
_Dim_
::
Make
(
op
->
name
,
op
->
sym_dim
);
}
#define __(x__) Expr Visit(const ir::intrinsics::x__* op);
INTRINSIC_KIND_FOR_EACH
(
__
)
#undef __
...
...
@@ -474,7 +484,7 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) {
return
intrinsics
::
BuiltinIntrin
::
Make
(
op
->
name
,
op
->
args
,
op
->
id
,
op
->
arg_nums
,
op
->
type
());
}
}
// namespace
Expr
IRCopy
(
Expr
x
)
{
IRCopyVisitor
visitor
;
auto
copied
=
visitor
.
Visit
(
&
x
);
...
...
@@ -507,6 +517,6 @@ std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x) {
}
return
res
;
}
}
// namespace
optim
}
// namespace ir_utils
}
// namespace
ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_copy.h
View file @
01a10755
...
...
@@ -24,9 +24,8 @@ namespace cinn {
namespace
ir
{
class
ModuleExpr
;
}
// namespace ir
namespace
optim
{
namespace
ir_utils
{
//! Shallow copy an expression.
Expr
IRCopy
(
Expr
x
);
...
...
@@ -39,5 +38,6 @@ ir::LoweredFunc IRCopy(const ir::LoweredFunc& x);
std
::
vector
<
ir
::
LoweredFunc
>
IRCopy
(
const
std
::
vector
<
ir
::
LoweredFunc
>&
x
);
}
// namespace optim
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_nodes_collector.cc
View file @
01a10755
...
...
@@ -15,14 +15,14 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include <glog/logging.h>
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
namespace
{
struct
IrNodesCollector
:
public
IRVisitorRequireReImpl
<
void
>
{
using
teller_t
=
std
::
function
<
bool
(
const
Expr
*
)
>
;
using
handler_t
=
std
::
function
<
void
(
const
Expr
*
)
>
;
...
...
@@ -207,5 +207,116 @@ std::set<Expr> CollectReferencedTensors(
return
ts0
;
}
std
::
vector
<
std
::
string
>
CollectUndefinedVars
(
const
Expr
*
e
)
{
struct
Mutator
:
public
ir
::
IRMutator
<
const
Expr
*>
{
using
ir
::
IRMutator
<
const
Expr
*>::
Visit
;
std
::
vector
<
std
::
string
>
undefined_vars
;
std
::
set
<
std
::
string
>
defined_vars
;
std
::
set
<
std
::
string
>
used_vars
;
void
CollectVarDef
(
const
std
::
string
&
var
)
{
CHECK
(
!
defined_vars
.
count
(
var
))
<<
"var "
<<
var
<<
" has been defined, please check"
;
CHECK
(
!
used_vars
.
count
(
var
))
<<
"var "
<<
var
<<
" is wrongly used before definition"
;
defined_vars
.
insert
(
var
);
}
void
ClearVar
(
const
std
::
string
&
var
)
{
defined_vars
.
erase
(
var
);
used_vars
.
erase
(
var
);
}
void
CollectVarUse
(
const
std
::
string
&
var
)
{
used_vars
.
insert
(
var
);
if
(
defined_vars
.
count
(
var
)
==
0
)
{
undefined_vars
.
push_back
(
var
);
}
}
void
Visit
(
const
ir
::
Let
*
op
,
const
Expr
*
expr
)
override
{
Expr
symbol
=
op
->
symbol
;
auto
var
=
symbol
.
as_var_ref
();
CHECK
(
var
.
defined
());
CollectVarDef
(
var
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Let
>
();
Visit
(
&
node
->
body
,
&
node
->
body
);
}
void
Visit
(
const
ir
::
For
*
op
,
const
Expr
*
expr
)
override
{
CollectVarDef
(
op
->
loop_var
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
For
>
();
Visit
(
&
node
->
min
,
&
node
->
min
);
Visit
(
&
node
->
extent
,
&
node
->
extent
);
Visit
(
&
node
->
body
,
&
node
->
body
);
ClearVar
(
op
->
loop_var
->
name
);
}
void
Visit
(
const
ir
::
Load
*
op
,
const
Expr
*
expr
)
override
{
auto
tensor
=
op
->
tensor
.
as_tensor_ref
();
CollectVarUse
(
tensor
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Load
>
();
for
(
auto
&
idx
:
node
->
indices
)
Visit
(
&
idx
,
&
idx
);
}
void
Visit
(
const
ir
::
Store
*
op
,
const
Expr
*
expr
)
override
{
auto
tensor
=
op
->
tensor
.
as_tensor_ref
();
CollectVarUse
(
tensor
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Store
>
();
for
(
auto
&
idx
:
node
->
indices
)
Visit
(
&
idx
,
&
idx
);
Visit
(
&
node
->
value
,
&
node
->
value
);
}
void
Visit
(
const
ir
::
_Var_
*
op
,
const
Expr
*
expr
)
override
{
CollectVarUse
(
op
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
_Var_
>
();
if
(
node
->
lower_bound
.
defined
())
{
Visit
(
&
node
->
lower_bound
,
&
node
->
lower_bound
);
}
if
(
node
->
upper_bound
.
defined
())
{
Visit
(
&
node
->
upper_bound
,
&
node
->
upper_bound
);
}
}
void
Visit
(
const
ir
::
Reduce
*
op
,
const
Expr
*
expr
)
override
{
for
(
auto
&
axis
:
op
->
reduce_axis
)
{
CollectVarDef
(
axis
->
name
);
}
auto
*
node
=
expr
->
As
<
ir
::
Reduce
>
();
if
(
node
->
init
.
defined
())
Visit
(
&
node
->
init
,
&
node
->
init
);
Visit
(
&
node
->
body
,
&
node
->
body
);
}
};
Mutator
mutator
;
mutator
.
Visit
(
e
,
e
);
return
mutator
.
undefined_vars
;
}
std::set<std::string> CollectTensorNeedsWrite(const Expr* e) {
  std::set<std::string> tensor_written;
  // Record the tensor name behind each node the teller accepted.
  IrNodesCollector::handler_t record = [&](const Expr* x) {
    if (auto* store = x->As<ir::Store>()) {
      tensor_written.insert(store->tensor.As<ir::_Tensor_>()->name);
    }
    if (auto* tensor = x->As<ir::_Tensor_>()) {
      tensor_written.insert(tensor->name);
    }
  };
  // Accept stores into tensors, and tensors produced by call nodes.
  IrNodesCollector::teller_t match = [](const Expr* x) {
    auto* store = x->As<ir::Store>();
    if (store && store->tensor.As<ir::_Tensor_>()) {
      return true;
    }
    auto* tensor = x->As<ir::_Tensor_>();
    return tensor != nullptr && tensor->is_call_node();
  };
  IrNodesCollector collector(std::move(match), std::move(record), false);
  collector.Visit(e);
  return tensor_written;
}
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_nodes_collector.h
View file @
01a10755
...
...
@@ -18,7 +18,7 @@
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
/**
* Collect the IR Nodes(without duplication) in the expression.
*/
...
...
@@ -65,5 +65,24 @@ std::map<std::string, Expr> CollectTensorMap(
return
true
;
});
/**
* Collect undefined vars in the scope.
*
* e.g.
*
* The expression:
* for i
* for j
* a[i, j] = b[i, j]
*
* here a, b are vars without definition
*/
std
::
vector
<
std
::
string
>
CollectUndefinedVars
(
const
Expr
*
e
);
/**
* Collect the Tensor Nodes which will be Writed by Store or Call Nodes
*/
std
::
set
<
std
::
string
>
CollectTensorNeedsWrite
(
const
Expr
*
e
);
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_printer.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/utils/ir_printer.h"
#include <algorithm>
#include <iomanip>
#include <limits>
#include <vector>
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
ir
{
using
common
::
bfloat16
;
using
common
::
float16
;
void
IrPrinter
::
Print
(
const
Expr
&
e
)
{
IRVisitorRequireReImpl
::
Visit
(
&
e
);
os_
<<
str_
;
str_
=
""
;
}
void
IrPrinter
::
Print
(
const
std
::
vector
<
Expr
>
&
exprs
,
const
std
::
string
&
splitter
)
{
for
(
std
::
size_t
i
=
0
;
!
exprs
.
empty
()
&&
i
+
1
<
exprs
.
size
();
i
++
)
{
Visit
(
exprs
[
i
]);
str_
+=
splitter
;
}
if
(
!
exprs
.
empty
())
Visit
(
exprs
.
back
());
os_
<<
str_
;
str_
=
""
;
}
void
IrPrinter
::
Visit
(
const
IntImm
*
x
)
{
if
(
x
->
type
().
is_int
(
64
))
{
str_
+=
std
::
to_string
(
x
->
value
);
str_
+=
"ll"
;
}
else
if
(
x
->
type
().
is_int
(
32
))
{
str_
+=
std
::
to_string
(
x
->
value
);
}
else
if
(
x
->
type
().
is_int
(
16
))
{
str_
+=
"(int16_t)"
;
str_
+=
std
::
to_string
(
x
->
value
);
}
else
if
(
x
->
type
().
is_int
(
8
))
{
str_
+=
"(int8_t)"
;
str_
+=
std
::
to_string
(
x
->
value
);
}
else
{
LOG
(
FATAL
)
<<
"Not support int type: "
<<
x
->
type
();
}
}
void
IrPrinter
::
Visit
(
const
UIntImm
*
x
)
{
if
(
x
->
type
().
is_uint
(
64
))
{
str_
+=
std
::
to_string
(
x
->
value
);
str_
+=
"ull"
;
}
else
if
(
x
->
type
().
is_uint
(
32
))
{
str_
+=
std
::
to_string
(
x
->
value
);
}
else
if
(
x
->
type
().
is_uint
(
16
))
{
str_
+=
"(uint16_t)"
;
str_
+=
std
::
to_string
(
x
->
value
);
}
else
if
(
x
->
type
().
is_uint
(
8
))
{
str_
+=
"(uint8_t)"
;
str_
+=
std
::
to_string
(
x
->
value
);
}
else
if
(
x
->
type
().
is_uint
(
1
))
{
if
(
x
->
value
)
{
str_
+=
"true"
;
}
else
{
str_
+=
"false"
;
}
}
else
{
LOG
(
FATAL
)
<<
"Not support uint type: "
<<
x
->
type
();
}
}
void
IrPrinter
::
Visit
(
const
FloatImm
*
x
)
{
std
::
ostringstream
ss
;
if
(
x
->
type
().
is_float16
())
{
if
(
std
::
isinf
(
x
->
value
))
{
ss
<<
"cinn::common::raw_uint16_to_float16(0x7c00)"
;
}
else
if
(
std
::
isnan
(
x
->
value
))
{
ss
<<
"cinn::common::raw_uint16_to_float16(0x7e00)"
;
}
else
{
ss
<<
"(float16)"
;
ss
<<
std
::
setprecision
(
std
::
numeric_limits
<
float16
>::
max_digits10
);
ss
<<
static_cast
<
float16
>
(
x
->
value
)
<<
"f"
;
}
}
else
if
(
x
->
type
().
is_bfloat16
())
{
if
(
std
::
isinf
(
x
->
value
))
{
ss
<<
"cinn::common::raw_uint16_to_bfloat16(0x7F80)"
;
}
else
if
(
std
::
isnan
(
x
->
value
))
{
ss
<<
"cinn::common::raw_uint16_to_bfloat16(0x7FC0)"
;
}
else
{
ss
<<
"(bfloat16)"
;
ss
<<
std
::
setprecision
(
std
::
numeric_limits
<
bfloat16
>::
max_digits10
);
ss
<<
static_cast
<
bfloat16
>
(
x
->
value
)
<<
"f"
;
}
}
else
if
(
x
->
type
().
is_float
(
32
))
{
ss
<<
std
::
setprecision
(
std
::
numeric_limits
<
float
>::
max_digits10
);
ss
<<
std
::
showpoint
;
ss
<<
x
->
value
;
if
(
std
::
isfinite
(
x
->
value
))
{
ss
<<
"f"
;
}
}
else
if
(
x
->
type
().
is_float
(
64
))
{
ss
<<
std
::
setprecision
(
std
::
numeric_limits
<
double
>::
max_digits10
);
ss
<<
std
::
showpoint
;
ss
<<
x
->
value
;
}
else
{
LOG
(
FATAL
)
<<
"Not support float type: "
<<
x
->
type
();
}
str_
+=
ss
.
str
();
}
void
IrPrinter
::
Visit
(
const
StringImm
*
x
)
{
str_
+=
"
\"
"
;
str_
+=
x
->
value
;
str_
+=
"
\"
"
;
}
void
IrPrinter
::
Visit
(
const
Add
*
x
)
{
PrintBinaryOp
(
"+"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Sub
*
x
)
{
PrintBinaryOp
(
"-"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Mul
*
x
)
{
PrintBinaryOp
(
"*"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Div
*
x
)
{
PrintBinaryOp
(
"/"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Mod
*
x
)
{
PrintBinaryOp
(
"%"
,
x
);
}
void
IrPrinter
::
Visit
(
const
EQ
*
x
)
{
PrintBinaryOp
(
"=="
,
x
);
}
void
IrPrinter
::
Visit
(
const
NE
*
x
)
{
PrintBinaryOp
(
"!="
,
x
);
}
void
IrPrinter
::
Visit
(
const
LT
*
x
)
{
PrintBinaryOp
(
"<"
,
x
);
}
void
IrPrinter
::
Visit
(
const
LE
*
x
)
{
PrintBinaryOp
(
"<="
,
x
);
}
void
IrPrinter
::
Visit
(
const
GT
*
x
)
{
PrintBinaryOp
(
">"
,
x
);
}
void
IrPrinter
::
Visit
(
const
GE
*
x
)
{
PrintBinaryOp
(
">="
,
x
);
}
void
IrPrinter
::
Visit
(
const
And
*
x
)
{
PrintBinaryOp
(
"and"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Or
*
x
)
{
PrintBinaryOp
(
"or"
,
x
);
}
void
IrPrinter
::
Visit
(
const
Not
*
x
)
{
str_
+=
"!"
;
Visit
(
x
->
v
());
}
void
IrPrinter
::
Visit
(
const
Min
*
x
)
{
str_
+=
"cinn_min("
;
Visit
(
x
->
a
());
str_
+=
", "
;
Visit
(
x
->
b
());
str_
+=
")"
;
}
void
IrPrinter
::
Visit
(
const
Max
*
x
)
{
str_
+=
"cinn_max("
;
Visit
(
x
->
a
());
str_
+=
", "
;
Visit
(
x
->
b
());
str_
+=
")"
;
}
void
IrPrinter
::
Visit
(
const
Minus
*
x
)
{
str_
+=
"-("
;
Visit
(
x
->
v
());
str_
+=
")"
;
}
void
IrPrinter
::
Visit
(
const
For
*
x
)
{
if
(
x
->
is_parallel
())
{
str_
+=
"parallel for ("
;
}
else
if
(
x
->
is_unrolled
())
{
str_
+=
"unroll for ("
;
}
else
if
(
x
->
is_vectorized
())
{
int
factor
=
x
->
vectorize_info
().
factor
;
str_
+=
"vectorize["
;
str_
+=
std
::
to_string
(
factor
);
str_
+=
"] for ("
;
}
else
if
(
x
->
is_binded
())
{
auto
&
bind_info
=
x
->
bind_info
();
if
(
bind_info
.
valid
())
{
char
axis_name
=
'x'
+
bind_info
.
offset
;
auto
for_type
=
bind_info
.
for_type
;
std
::
string
prefix
=
for_type
==
ForType
::
GPUBlock
?
"blockIdx."
:
"threadIdx."
;
str_
+=
"thread_bind["
;
str_
+=
prefix
;
str_
+=
axis_name
;
str_
+=
"] for ("
;
}
else
{
str_
+=
"thread_bind[invalid info] for ("
;
}
}
else
if
(
x
->
is_serial
())
{
str_
+=
"serial for ("
;
}
else
if
(
x
->
is_default
())
{
str_
+=
"default for ("
;
}
else
{
str_
+=
"for ("
;
}
Visit
(
x
->
loop_var
);
str_
+=
", "
;
Visit
(
x
->
min
);
str_
+=
", "
;
Visit
(
x
->
extent
);
str_
+=
")
\n
"
;
DoIndent
();
Visit
(
x
->
body
);
}
void
IrPrinter
::
Visit
(
const
PolyFor
*
x
)
{
if
(
x
->
is_parallel
())
{
str_
+=
"parallel poly_for ("
;
}
else
{
str_
+=
"poly_for ("
;
}
Visit
(
x
->
iterator
);
str_
+=
", "
;
Visit
(
x
->
init
);
str_
+=
", "
;
Visit
(
x
->
condition
);
str_
+=
", "
;
Visit
(
x
->
inc
);
str_
+=
")
\n
"
;
DoIndent
();
Visit
(
x
->
body
);
}
// Print an if/else statement with brace-delimited, indented branches.
// The else block is only printed when false_case is defined.
void IrPrinter::Visit(const IfThenElse *x) {
  str_ += "if (";
  Visit(x->condition);
  str_ += ") {\n";
  IncIndent();
  DoIndent();
  Visit(x->true_case);
  DecIndent();
  str_ += "\n";
  DoIndent();
  str_ += "}";
  if (x->false_case.defined()) {
    str_ += " else {\n";
    IncIndent();
    DoIndent();
    Visit(x->false_case);
    str_ += "\n";
    DecIndent();
    DoIndent();
    str_ += "}";
  }
}
// Print a statement Block as "{ ... }". Statements are newline-separated;
// the last one is emitted without a trailing newline so the closing brace
// formatting is uniform.
void IrPrinter::Visit(const Block *x) {
  str_ += "{\n";
  IncIndent();
  // All but the last statement get a trailing "\n".
  for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) {
    DoIndent();
    Visit(x->stmts[i]);
    str_ += "\n";
  }
  if (!x->stmts.empty()) {
    DoIndent();
    Visit(x->stmts.back());
  }
  DecIndent();
  str_ += "\n";
  DoIndent();
  str_ += "}";
}
// Print a Call node as "name(read_args..., write_args...)", comma-separated,
// with a separating comma between the two argument groups only when both are
// non-empty.
void IrPrinter::Visit(const Call *x) {
  str_ += x->name;
  str_ += "(";
  if (!x->read_args.empty()) {
    for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) {
      Visit(x->read_args[i]);
      str_ += ", ";
    }
    Visit(x->read_args.back());
  }
  if (!x->write_args.empty()) {
    if (!x->read_args.empty()) str_ += ", ";
    for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) {
      Visit(x->write_args[i]);
      str_ += ", ";
    }
    Visit(x->write_args.back());
  }
  str_ += ")";
}
// Print a Cast as a function-style conversion: "<type>(value)".
void IrPrinter::Visit(const Cast *x) {
  str_ += x->type().to_string();
  str_ += "(";
  Visit(x->v());
  str_ += ")";
}
// _Module_ nodes have no textual representation here; module printing is
// handled by operator<<(std::ostream&, const ir::Module&).
void IrPrinter::Visit(const _Module_ *x) {}
// Print a variable by its name only.
void IrPrinter::Visit(const _Var_ *x) { str_ += x->name; }
// Print an Alloc as "alloc(<buffer name>, <extents>)". The destination must
// be a _Buffer_ node.
void IrPrinter::Visit(const Alloc *x) {
  auto *buffer = x->destination.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += "alloc(";
  str_ += buffer->name;
  str_ += ", ";
  Visit(x->extents);
  str_ += ")";
}
// Print a Select (ternary) as "select(cond, true_value, false_value)".
void IrPrinter::Visit(const Select *x) {
  str_ += "select(";
  Visit(x->condition);
  str_ += ", ";
  Visit(x->true_value);
  str_ += ", ";
  Visit(x->false_value);
  str_ += ")";
}
// Print a Load as "<target>[i0, i1, ...]". The target is a tensor name when
// the address is a tensor, or the printed scalar expression otherwise.
void IrPrinter::Visit(const Load *x) {
  if (x->is_addr_tensor()) {
    auto *tensor = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor);
    str_ += tensor->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  // Comma-separate all but the last index; safe for empty indices since
  // i + 1 < 0 never holds for size_t arithmetic on an empty vector.
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "]";
}
// Print a Store as "<target>[i0, i1, ...] = <value>". Mirrors the Load
// printer for the target and index list, then appends the stored value.
void IrPrinter::Visit(const Store *x) {
  if (x->is_addr_tensor()) {
    auto *tensor_node = x->tensor.As<ir::_Tensor_>();
    CHECK(tensor_node);
    str_ += tensor_node->name;
  } else if (x->is_addr_scalar()) {
    Visit(x->tensor);
  } else {
    CINN_NOT_IMPLEMENTED
  }
  str_ += "[";
  for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
    Visit(x->indices[i]);
    str_ += ", ";
  }
  if (!x->indices.empty()) Visit(x->indices.back());
  str_ += "] = ";
  Visit(x->value);
}
// Print a Free as "free(<buffer name>)". The destination must be a _Buffer_.
void IrPrinter::Visit(const Free *x) {
  auto *buffer = x->destination.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += "free(";
  str_ += buffer->name;
  str_ += ")";
}
// Append the current indentation (indent_ spaces) to the output buffer.
void IrPrinter::DoIndent() { str_ += std::string(indent_, ' '); }
// Increase indentation by one unit (indent_unit spaces).
void IrPrinter::IncIndent() { indent_ += indent_unit; }
// Decrease indentation by one unit; callers must pair with IncIndent.
void IrPrinter::DecIndent() { indent_ -= indent_unit; }
// Print a buffer declaration as "_Buffer_<type: d0,d1,...>(name)", where the
// dims are the stringified shape expressions.
void IrPrinter::Visit(const _Buffer_ *x) {
  std::vector<std::string> dim_names;
  std::transform(x->shape.begin(),
                 x->shape.end(),
                 std::back_inserter(dim_names),
                 [&](const Expr &x) { return utils::GetStreamCnt(x); });
  str_ += "_Buffer_<";
  str_ += x->type().to_string();
  str_ += ": ";
  str_ += utils::Join(dim_names, ",");
  str_ += ">(";
  str_ += x->name;
  str_ += ")";
}
// Print a tensor as "Tensor(name, [d0,d1,...])" with comma-joined shape dims.
void IrPrinter::Visit(const _Tensor_ *x) {
  str_ += "Tensor(";
  str_ += x->name;
  str_ += ", ";
  str_ += "[";
  if (!x->shape.empty()) {
    for (std::size_t i = 0; i + 1 < x->shape.size(); i++) {
      Visit(x->shape[i]);
      str_ += ",";
    }
    Visit(x->shape.back());
  }
  str_ += "])";
}
// Print a lowered function as "function <name> (arg0, arg1, ...)" followed by
// its body.
void IrPrinter::Visit(const _LoweredFunc_ *f) {
  str_ += "function ";
  str_ += f->name;
  str_ += " ";
  std::vector<std::string> arg_names;
  for (auto &arg : f->args) {
    arg_names.push_back(arg.name());
  }
  str_ += "(";
  str_ += utils::Join(arg_names, ", ");
  str_ += ")\n";
  Visit(f->body);
}
// Print a Let binding as "<type> <symbol>" or "<type> <symbol> = <body>" when
// an initializer body is defined. The bound type must be valid.
void IrPrinter::Visit(const Let *f) {
  CHECK(f->type().valid());
  str_ += f->type().to_string();
  str_ += " ";
  Visit(f->symbol);
  if (f->body.defined()) {
    str_ += " = ";
    Visit(f->body);
  }
}
// Print a Reduce node as "Reduce(<op>, <body>,<init>)".
// NOTE(review): the operation labels mix cases ("sum"/"sub" vs "Div"/"Mul"/
// "Max"/"Min"); kept as-is since downstream text comparisons may rely on the
// exact spelling.
void IrPrinter::Visit(const Reduce *f) {
  str_ += "Reduce(";
  switch (f->reduce_type) {
    case Reduce::ReduceType::kSum:
      str_ += "sum";
      break;
    case Reduce::ReduceType::kSub:
      str_ += "sub";
      break;
    case Reduce::ReduceType::kDiv:
      str_ += "Div";
      break;
    case Reduce::ReduceType::kMul:
      str_ += "Mul";
      break;
    case Reduce::ReduceType::kMax:
      str_ += "Max";
      break;
    case Reduce::ReduceType::kMin:
      str_ += "Min";
      break;
    case Reduce::ReduceType::kAll:
      str_ += "&&";
      break;
    case Reduce::ReduceType::kAny:
      str_ += "||";
      break;
  }
  str_ += ", ";
  Visit(f->body);
  str_ += ",";
  Visit(f->init);
  str_ += ")";
}
// Print a Ramp vector node as "Ramp(base,stride,lanes)".
void IrPrinter::Visit(const Ramp *x) {
  str_ += "Ramp(";
  Visit(x->base);
  str_ += ",";
  Visit(x->stride);
  str_ += ",";
  str_ += std::to_string(x->lanes);
  str_ += ")";
}
// Print a Broadcast vector node as "Broadcast(value,lanes)".
void IrPrinter::Visit(const Broadcast *x) {
  str_ += "Broadcast(";
  Visit(x->value);
  str_ += ",";
  str_ += std::to_string(x->lanes);
  str_ += ")";
}
// Print a fraction node as an infix division: "(a / b)".
void IrPrinter::Visit(const FracOp *x) {
  str_ += "(";
  Visit(x->a());
  str_ += " / ";
  Visit(x->b());
  str_ += ")";
}
void
IrPrinter
::
Visit
(
const
Product
*
x
)
{
str_
+=
"("
;
for
(
std
::
size_t
i
=
0
;
i
+
1
<
x
->
operands
().
size
();
i
++
)
{
Visit
(
x
->
operand
(
i
));
str_
+=
" * "
;
}
if
(
!
x
->
operands
().
empty
())
Visit
(
x
->
operands
().
back
());
str_
+=
")"
;
}
void
IrPrinter
::
Visit
(
const
Sum
*
x
)
{
str_
+=
"("
;
for
(
std
::
size_t
i
=
0
;
i
+
1
<
x
->
operands
().
size
();
i
++
)
{
Visit
(
x
->
operand
(
i
));
str_
+=
" + "
;
}
if
(
!
x
->
operands
().
empty
())
Visit
(
x
->
operands
().
back
());
str_
+=
")"
;
}
// Print a PrimitiveNode as "name(a0,a1,...)", where each argument group is
// itself a comma-joined list of stringified expressions.
void IrPrinter::Visit(const PrimitiveNode *x) {
  str_ += x->name;
  str_ += "(";
  std::vector<std::string> args_repr;
  for (auto &args : x->arguments) {
    std::vector<std::string> arg_repr;
    for (auto &arg : args) {
      arg_repr.push_back(utils::GetStreamCnt(arg));
    }
    args_repr.push_back(utils::Join(arg_repr, ","));
  }
  str_ += utils::Join(args_repr, ",");
  str_ += ")";
}
// Print a buffer access range as "buffer[v0(lb:ub), v1(lb:ub), ...]", with
// "undefined" substituted for any missing bound.
void IrPrinter::Visit(const _BufferRange_ *x) {
  auto *buffer = x->buffer.As<ir::_Buffer_>();
  CHECK(buffer);
  str_ += buffer->name;
  str_ += "[";
  for (std::size_t i = 0; i < x->ranges.size(); i++) {
    if (i) str_ += ", ";
    auto &range = x->ranges[i];
    str_ += range->name;
    str_ += "(";
    if (range->lower_bound.defined()) {
      Visit(range->lower_bound);
      str_ += ":";
    } else {
      str_ += "undefined:";
    }
    if (range->upper_bound.defined()) {
      Visit(range->upper_bound);
    } else {
      str_ += "undefined";
    }
    str_ += ")";
  }
  str_ += "]";
}
// ScheduleBlock nodes are printed through their enclosing
// ScheduleBlockRealize; nothing is emitted for the bare node.
void IrPrinter::Visit(const ScheduleBlock *x) {}
void
IrPrinter
::
Visit
(
const
ScheduleBlockRealize
*
x
)
{
auto
*
schedule_block
=
x
->
schedule_block
.
As
<
ScheduleBlock
>
();
str_
+=
"ScheduleBlock("
;
str_
+=
schedule_block
->
name
;
str_
+=
")
\n
"
;
DoIndent
();
str_
+=
"{
\n
"
;
// print block vars and bindings
auto
iter_vars
=
schedule_block
->
iter_vars
;
auto
iter_values
=
x
->
iter_values
;
CHECK_EQ
(
iter_vars
.
size
(),
iter_values
.
size
());
IncIndent
();
if
(
!
iter_vars
.
empty
())
DoIndent
();
for
(
std
::
size_t
i
=
0
;
i
<
iter_vars
.
size
();
i
++
)
{
if
(
i
)
str_
+=
", "
;
str_
+=
iter_vars
[
i
]
->
name
;
}
if
(
!
iter_vars
.
empty
())
str_
+=
" = axis.bind("
;
for
(
std
::
size_t
i
=
0
;
i
<
iter_values
.
size
();
i
++
)
{
if
(
i
)
str_
+=
", "
;
Visit
(
iter_values
[
i
]);
}
if
(
!
iter_vars
.
empty
())
str_
+=
")
\n
"
;
// print block body
if
(
!
schedule_block
->
read_buffers
.
empty
())
{
DoIndent
();
str_
+=
"read_buffers("
;
auto
&
read_buffers
=
schedule_block
->
read_buffers
;
for
(
std
::
size_t
i
=
0
;
i
<
read_buffers
.
size
();
i
++
)
{
if
(
i
)
str_
+=
", "
;
Visit
(
read_buffers
[
i
]);
}
str_
+=
")
\n
"
;
}
if
(
!
schedule_block
->
write_buffers
.
empty
())
{
DoIndent
();
str_
+=
"write_buffers("
;
auto
&
write_buffers
=
schedule_block
->
write_buffers
;
for
(
std
::
size_t
i
=
0
;
i
<
write_buffers
.
size
();
i
++
)
{
if
(
i
)
str_
+=
", "
;
Visit
(
write_buffers
[
i
]);
}
str_
+=
")
\n
"
;
}
if
(
!
schedule_block
->
attrs
.
empty
())
{
DoIndent
();
str_
+=
"attrs("
;
bool
comma
=
false
;
for
(
auto
&&
kv
:
schedule_block
->
attrs
)
{
if
(
comma
)
str_
+=
", "
;
str_
+=
kv
.
first
;
str_
+=
":"
;
absl
::
visit
(
[
this
](
auto
&&
arg
)
{
std
::
ostringstream
ss
;
ss
<<
arg
;
this
->
str_
+=
ss
.
str
();
},
kv
.
second
);
comma
=
true
;
}
str_
+=
")
\n
"
;
}
DoIndent
();
Visit
(
schedule_block
->
body
);
str_
+=
"
\n
"
;
DecIndent
();
DoIndent
();
str_
+=
"}"
;
}
// Dispatch an IntrinsicOp to its concrete intrinsics::* overload, switching
// on the runtime kind and down-casting with llvm::dyn_cast.
void IrPrinter::Visit(const IntrinsicOp *x) {
  switch (x->getKind()) {
#define __(op__)                                \
  case IntrinsicKind::k##op__:                  \
    Visit(llvm::dyn_cast<intrinsics::op__>(x)); \
    break;
    INTRINSIC_KIND_FOR_EACH(__)
#undef __
  }
}
// Print the buffer_get_data_handle intrinsic as "<intrinsic>(<buffer>)".
void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_handle;
  // Fix: the opening parenthesis was missing, so the printed form had an
  // unbalanced trailing ")"; other intrinsic printers (PodValueToX, GetAddr)
  // emit both parentheses.
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Print the buffer_get_data_const_handle intrinsic as "<intrinsic>(<buffer>)".
void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_const_handle;
  // Fix: same missing "(" defect as BufferGetDataHandle — the output
  // previously contained only the closing parenthesis.
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Print the pod-value conversion intrinsic as
// "pod_value_to_<output type>(<pod_value_ptr>)".
void IrPrinter::Visit(const intrinsics::PodValueToX *x) {
  str_ += "pod_value_to_";
  str_ += x->GetOutputType(0).to_string();
  str_ += "(";
  Visit(x->pod_value_ptr);
  str_ += ")";
}
// Print the buffer_create intrinsic as a zero-argument call: "<intrinsic>()".
void IrPrinter::Visit(const intrinsics::BufferCreate *x) {
  str_ += runtime::intrinsic::buffer_create;
  str_ += "()";
}
// Print the address-of intrinsic as "get_addr(<data>)".
void IrPrinter::Visit(const intrinsics::GetAddr *x) {
  str_ += "get_addr(";
  Visit(x->data);
  str_ += ")";
}
// Print the args-construct intrinsic as "<intrinsic>(a0, a1, ...)", reusing
// the vector<Expr> overload of Visit for the comma-joined argument list.
void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) {
  str_ += runtime::intrinsic::args_construct_repr;
  str_ += "(";
  Visit(std::vector<Expr>(x->args.begin(), x->args.end()));
  str_ += ")";
}
// Print a builtin intrinsic as "<intrinsic repr>_<name>(a0, a1, ...)".
void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) {
  str_ += runtime::intrinsic::builtin_intrin_repr;
  str_ += "_";
  str_ += x->name;
  str_ += "(";
  if (!x->args.empty()) {
    for (std::size_t i = 0; i + 1 < x->args.size(); i++) {
      Visit(x->args[i]);
      str_ += ", ";
    }
    Visit(x->args.back());
  }
  str_ += ")";
}
// Stream an Expr by printing it into a temporary IrPrinter buffer and
// forwarding the result to the output stream.
std::ostream &operator<<(std::ostream &os, Expr a) {
  std::stringstream ss;
  IrPrinter printer(ss);
  printer.Print(a);
  os << ss.str();
  return os;
}
// Stream a list of Exprs using IrPrinter's vector overload (", "-separated).
std::ostream &operator<<(std::ostream &os, const std::vector<Expr> &a) {
  std::stringstream ss;
  IrPrinter printer(ss);
  printer.Print(a);
  os << ss.str();
  return os;
}
// Stream a whole Module as "Module <name> { ... }" with each contained
// function printed on its own line.
std::ostream &operator<<(std::ostream &os, const ir::Module &m) {
  os << "Module " << m->name << " {\n\n";
  for (auto &fn : m->functions) {
    os << fn << '\n';
  }
  os << "\n\n}";
  return os;
}
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_printer.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
lang
{
class
LoweredFunc
;
}
// namespace lang
namespace
ir
{
class
Module
;
//! Pretty-printer for CINN IR. Implements a Visit overload per IR node type
//! (generated from NODETY_FORALL / INTRINSIC_KIND_FOR_EACH) that appends the
//! node's textual form to the internal buffer `str_`.
struct IrPrinter : public IRVisitorRequireReImpl<void> {
  explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}

  //! Emit an expression on the output stream.
  void Print(const Expr &e);
  //! Emit a expression list with , splitted.
  void Print(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ");
  //! Emit a binary operator
  template <typename IRN>
  void PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x);

  //! Prefix the current line with `indent_` spaces.
  void DoIndent();
  //! Increase the indent size.
  void IncIndent();
  //! Decrease the indent size.
  void DecIndent();

  //! The stream that Print() ultimately writes to.
  std::ostream &os() { return os_; }

  //! Dispatch an Expr to the matching node-type overload.
  void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
  //! Print a list of expressions separated by `splitter` (no trailing one).
  void Visit(const std::vector<Expr> &exprs,
             const std::string &splitter = ", ") {
    for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
      Visit(exprs[i]);
      str_ += splitter;
    }
    if (!exprs.empty()) Visit(exprs.back());
  }

// One Visit override per IR node type.
#define __(op__) void Visit(const op__ *x) override;
  NODETY_FORALL(__)
#undef __

// One virtual Visit per intrinsic op type.
#define __(op__) virtual void Visit(const intrinsics::op__ *x);
  INTRINSIC_KIND_FOR_EACH(__)
#undef __

 protected:
  // Accumulated textual output; flushed to os_ by Print().
  std::string str_;

 private:
  std::ostream &os_;
  // Current indentation in spaces.
  uint16_t indent_{};
  // Spaces added/removed per Inc/DecIndent call.
  const int indent_unit{2};
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Expr
a
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
Expr
>
&
a
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Module
&
m
);
// Print a binary node as a fully parenthesized infix expression:
// "(<a> <op> <b>)".
template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op,
                              const BinaryOpNode<IRN> *x) {
  str_ += "(";
  Visit(x->a());
  str_ += " ";
  str_ += op;
  str_ += " ";
  Visit(x->b());
  str_ += ")";
}
}
// namespace ir
}
// namespace cinn
paddle/cinn/
optim
/ir_replace.cc
→
paddle/cinn/
ir/utils
/ir_replace.cc
View file @
01a10755
...
...
@@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/
optim
/ir_replace.h"
#include "paddle/cinn/
ir/utils
/ir_replace.h"
#include <set>
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
optim
{
namespace
ir
{
namespace
ir_utils
{
using
utils
::
GetStreamCnt
;
namespace
{
...
...
@@ -42,14 +43,14 @@ struct IrReplaceMutator : ir::IRMutator<Expr*> {
void
Visit
(
const
ir
::
_Var_
*
op
,
Expr
*
expr
)
override
{
if
(
op
->
node_type
()
==
from_
->
node_type
()
&&
from_repr_
==
GetStreamCnt
(
*
expr
))
{
*
expr
=
optim
::
IRCopy
(
to_
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
to_
);
}
}
void
Visit
(
const
ir
::
Broadcast
*
op
,
Expr
*
expr
)
override
{
if
(
op
->
node_type
()
==
from_
->
node_type
()
&&
from_repr_
==
GetStreamCnt
(
*
expr
))
{
*
expr
=
optim
::
IRCopy
(
to_
);
*
expr
=
ir
::
ir_utils
::
IRCopy
(
to_
);
}
}
...
...
@@ -65,5 +66,6 @@ void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to) {
IrReplaceMutator
(
from
,
to
)(
expr
);
}
}
// namespace optim
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
Prev
1
…
19
20
21
22
23
24
25
26
27
…
29
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment