Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
604 additions
and
108 deletions
+604
-108
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
...uto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
+1
-1
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
...n/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
+39
-8
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
...nn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
+0
-1
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
...o_schedule/search_space/auto_gen_rule/auto_inline_test.cc
+13
-15
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
...n/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
+2
-2
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
...o_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
+4
-3
paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
...uto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
+1
-1
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc
...schedule/search_space/auto_gen_rule/multi_level_tiling.cc
+1
-1
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
...ule/search_space/auto_gen_rule/multi_level_tiling_test.cc
+15
-21
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc
...chedule/search_space/auto_gen_rule/reduction_factoring.cc
+203
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h
...schedule/search_space/auto_gen_rule/reduction_factoring.h
+59
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc
...le/search_space/auto_gen_rule/reduction_factoring_test.cc
+224
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
...uto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
+7
-6
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
...n/auto_schedule/search_space/auto_gen_rule/test_helper.cc
+4
-2
paddle/cinn/auto_schedule/search_space/search_space.cc
paddle/cinn/auto_schedule/search_space/search_space.cc
+1
-1
paddle/cinn/auto_schedule/search_space/search_state.cc
paddle/cinn/auto_schedule/search_space/search_state.cc
+5
-6
paddle/cinn/auto_schedule/search_space/search_state.h
paddle/cinn/auto_schedule/search_space/search_state.h
+3
-3
paddle/cinn/auto_schedule/search_space/search_state_test.cc
paddle/cinn/auto_schedule/search_space/search_state_test.cc
+10
-26
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
+3
-2
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc
...cinn/auto_schedule/search_strategy/evolutionary_search.cc
+9
-9
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
View file @
01a10755
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
#include <numeric>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
cinn
{
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
View file @
01a10755
...
@@ -26,10 +26,11 @@
...
@@ -26,10 +26,11 @@
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
@@ -49,7 +50,12 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
...
@@ -49,7 +50,12 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir
::
Expr
root
=
ir_sch
->
GetRootBlock
(
sche_block_realize_expr
);
ir
::
Expr
root
=
ir_sch
->
GetRootBlock
(
sche_block_realize_expr
);
// Check the schedule block to be inlined is not a reduce tensor.
// Check the schedule block to be inlined is not a reduce tensor.
std
::
set
<
ir
::
Expr
>
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
for
(
const
ir
::
Var
&
iter_var
:
sche_block
->
iter_vars
)
{
if
(
iter_var
->
is_reduce_axis
)
{
return
false
;
}
}
std
::
set
<
ir
::
Expr
>
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
if
(
find_store
.
size
()
!=
1UL
)
{
if
(
find_store
.
size
()
!=
1UL
)
{
return
false
;
return
false
;
...
@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
...
@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
return
false
;
return
false
;
}
}
// the xxx_reduce_init block cannot be inlined.
if
(
ir
::
IsReduceInitTensorName
(
tensor
->
name
))
{
return
false
;
}
// Skip external calls
std
::
vector
<
ir
::
Expr
>
consumers
=
ir
::
GetConsumers
(
sche_block_realize_expr
,
root
);
for
(
const
ir
::
Expr
&
consumer
:
consumers
)
{
std
::
set
<
ir
::
Expr
>
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
consumer
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
,
[
&
](
const
ir
::
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
.
as_tensor_ref
()
->
name
==
tensor
->
name
;
});
if
(
find_load
.
empty
())
{
return
false
;
}
}
// write_buffers.size() = 1 and read_buffers is empty, means const
// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
// we can inline to consumer
if
(
sche_block
->
read_buffers
.
empty
())
{
if
(
sche_block
->
read_buffers
.
empty
())
{
...
@@ -76,17 +105,19 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
...
@@ -76,17 +105,19 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
}
}
// Check this schedule block is the only writer of the tensor.
// Check this schedule block is the only writer of the tensor.
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
find_store
=
return
x
->
As
<
ir
::
Store
>
()
&&
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
(
x
->
As
<
ir
::
Store
>
()
->
tensor
).
as_tensor_ref
()
->
name
==
tensor
->
name
;
return
x
->
As
<
ir
::
Store
>
()
&&
});
(
x
->
As
<
ir
::
Store
>
()
->
tensor
).
as_tensor_ref
()
->
name
==
tensor
->
name
;
});
if
(
find_store
.
size
()
!=
1UL
)
{
if
(
find_store
.
size
()
!=
1UL
)
{
return
false
;
return
false
;
}
}
// Check there is no overlap between the buffers the schedule block reads and
// Check there is no overlap between the buffers the schedule block reads and
// writes.
// writes.
std
::
set
<
ir
::
Expr
>
find_load
=
std
::
set
<
ir
::
Expr
>
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
==
tensor_expr
;
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
==
tensor_expr
;
});
});
if
(
!
find_load
.
empty
())
{
if
(
!
find_load
.
empty
())
{
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
View file @
01a10755
...
@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule {
...
@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule {
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
override
;
const
std
::
string
&
block_name
)
override
;
private:
void
Apply
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
void
Apply
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
private:
private:
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
View file @
01a10755
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/cinn.h"
...
@@ -30,9 +31,9 @@
...
@@ -30,9 +31,9 @@
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/poly/stage.h"
...
@@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) {
...
@@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) {
ir
::
Tensor
C
=
Compute
(
ir
::
Tensor
C
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
B
(
i
)
+
ir
::
Expr
(
1.
f
);
},
"C"
);
{
M
},
[
&
](
Var
i
)
{
return
B
(
i
)
+
ir
::
Expr
(
1.
f
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
A
,
B
,
C
});
ast_gen_ius
::
TensorGroup
tensor_group
({
A
,
B
,
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
"TestAutoInline_SingleLoopInline"
,
lang
::
LowerToAstVec
(
"TestAutoInline_SingleLoopInline"
,
stages
,
{
A
,
C
},
{
A
,
C
},
{},
&
tensor_Group
,
{},
target
);
nullptr
,
target
,
true
);
VLOG
(
6
)
<<
"Expr after lowering:"
;
VLOG
(
6
)
<<
"Expr after lowering:"
;
VLOG
(
6
)
<<
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
funcs
[
0
]
->
body
;
...
@@ -161,14 +159,14 @@ TEST(AutoInline, AddReluInline) {
...
@@ -161,14 +159,14 @@ TEST(AutoInline, AddReluInline) {
"inferdtype"
);
"inferdtype"
);
const
auto
&
shape_dict
=
graph
->
GetAttrs
<
const
auto
&
shape_dict
=
graph
->
GetAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
auto
op_lowerer
=
std
::
make_unique
<
hlir
::
framework
::
OpLowerer
>
(
auto
op_lowerer
=
dtype_dict
,
shape_dict
,
target
);
hlir
::
framework
::
CreateOpLowerer
(
dtype_dict
,
shape_dict
,
target
);
EXPECT_EQ
(
graph
->
fusion_groups
.
size
(),
1UL
);
EXPECT_EQ
(
graph
->
fusion_groups
.
size
(),
1UL
);
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
op_lowerer
->
Lower
(
graph
->
fusion_groups
[
0
],
op_lowerer
.
Lower
(
graph
->
fusion_groups
[
0
],
/*apply_op_schedule = */
false
,
/*apply_op_schedule = */
false
,
/*apply_group_schedule=*/
false
);
/*apply_group_schedule=*/
false
);
VLOG
(
6
)
<<
"Expr before auto inline: "
<<
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before auto inline: "
<<
funcs
[
0
]
->
body
;
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
View file @
01a10755
...
@@ -18,10 +18,10 @@
...
@@ -18,10 +18,10 @@
#include <cstdlib>
#include <cstdlib>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
...
@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
return
false
;
return
false
;
};
};
auto
find_target_exprs
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_target_exprs
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
schedule_block
->
body
,
schedule_block
->
body
,
[
&
has_reduce_iter
,
&
has_nonserial_loop
](
const
Expr
*
x
)
{
[
&
has_reduce_iter
,
&
has_nonserial_loop
](
const
Expr
*
x
)
{
return
has_reduce_iter
(
x
)
||
has_nonserial_loop
(
x
);
return
has_reduce_iter
(
x
)
||
has_nonserial_loop
(
x
);
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
View file @
01a10755
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/lower.h"
...
@@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) {
...
@@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) {
#else
#else
Target
target
=
common
::
DefaultHostTarget
();
Target
target
=
common
::
DefaultHostTarget
();
#endif
#endif
a
uto
st
a
ge
s
=
CreateStages
({
C
});
ast
_
ge
n_ius
::
TensorGroup
tensor_group
({
C
});
auto
funcs
=
cinn
::
lang
::
LowerVec
(
auto
funcs
=
"test_init"
,
stages
,
{
A
,
B
,
C
},
{},
{},
nullptr
,
target
,
true
);
cinn
::
lang
::
LowerToAstVec
(
"test_init"
,
{
A
,
B
,
C
},
&
tensor_group
,
target
);
auto
ast_expr
=
funcs
[
0
]
->
body
;
auto
ast_expr
=
funcs
[
0
]
->
body
;
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
View file @
01a10755
...
@@ -20,8 +20,8 @@
...
@@ -20,8 +20,8 @@
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
cinn
{
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc
View file @
01a10755
...
@@ -29,11 +29,11 @@
...
@@ -29,11 +29,11 @@
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
View file @
01a10755
...
@@ -21,15 +21,16 @@
...
@@ -21,15 +21,16 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/poly/stage.h"
...
@@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) {
...
@@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) {
ir
::
Tensor
C
=
Compute
(
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
ast_gen_ius
::
TensorGroup
tensor_group
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerToAstVec
(
lang
::
LowerVec
(
"TestMultiLevelTile_SimpleLoops"
,
"TestMultiLevelTile_SimpleLoops"
,
{
C
},
&
tensor_group
,
target
);
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before MultiLevelTiling: "
;
VLOG
(
6
)
<<
"Expr before MultiLevelTiling: "
;
...
@@ -261,7 +255,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
...
@@ -261,7 +255,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
{
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
{
temp_matmul_out__reduce_init[
((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))
] = 0.00000000f
temp_matmul_out__reduce_init[
i0, i1
] = 0.00000000f
}
}
}
}
}
}
...
@@ -308,10 +302,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
...
@@ -308,10 +302,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer)
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i
(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined
)])
read_buffers(_temp_matmul_out[i
0_0(0:32), i1_0(0:32)], _X[i0_0(0:32), i2(0:32)], _Y[i2(0:32), i1_0(0:32
)])
write_buffers(_temp_matmul_out[i
(undefined:undefined), j(undefined:undefined
)])
write_buffers(_temp_matmul_out[i
0_0(0:32), i1_0(0:32
)])
{
{
temp_matmul_out_local_temp_buffer[
((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))
]))
temp_matmul_out_local_temp_buffer[
i0_0, i1_0] = (temp_matmul_out_local_temp_buffer[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0
]))
}
}
}
}
}
}
...
@@ -453,7 +447,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
...
@@ -453,7 +447,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
{
pad_temp_0[i,
j, k, a
] = select(((
a
<
17
) and ((
a
>= 1) and ((
k
<
17
) and (
k
>= 1)))), input[i,
j
, (
-1 + k), (-1 + a
)], -3.40282347e+38f)
pad_temp_0[i
0
,
i1, i2, i3
] = select(((
i3
<
(1 + 16)
) and ((
i3
>= 1) and ((
i2
<
(1 + 16)
) and (
i2
>= 1)))), input[i
0
,
i1
, (
i2 - 1), (i3 - 1
)], -3.40282347e+38f)
}
}
}
}
}
}
...
@@ -477,7 +471,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
...
@@ -477,7 +471,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
{
var_0__reduce_init[
((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)
] = -3.40282347e+38f
var_0__reduce_init[
i0_0, i1_0, i2_0, i3_0
] = -3.40282347e+38f
}
}
}
}
}
}
...
@@ -511,10 +505,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
...
@@ -511,10 +505,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
ScheduleBlock(var_0_local_temp_buffer)
ScheduleBlock(var_0_local_temp_buffer)
{
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i
(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined
)])
read_buffers(_var_0[i
0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)], _pad_temp_0[i0_1(0:2), i1_1(0:8
)])
write_buffers(_var_0[i
(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined
)])
write_buffers(_var_0[i
0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8
)])
{
{
var_0_local_temp_buffer[
((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j
_1
)
,
((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a
_1
)
] = cinn_max(var_0_local_temp_buffer[
((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j
_1
)
,
((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a
_1
)
], pad_temp_0_shared_temp_buffer[
((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j
_1
)
, ((
8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))
), ((
8
*
(i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0)
)])
var_0_local_temp_buffer[
i0_1, i1
_1,
i2_1, i3
_1] = cinn_max(var_0_local_temp_buffer[
i0_1, i1
_1,
i2_1, i3
_1], pad_temp_0_shared_temp_buffer[
i0_1, i1
_1, ((
2 * i2_1) + i4
), ((
2
*
i3_1) + i5
)])
}
}
}
}
}
}
...
@@ -533,7 +527,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
...
@@ -533,7 +527,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
{
ScheduleBlock(var_0)
ScheduleBlock(var_0)
{
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((
4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4
)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((
i_0_j_0_k_0_a_0_fused % 4) + (4 * ((i_j_k_a_fused / 2) % 2)
)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
namespace
cinn
{
namespace
auto_schedule
{
bool
ReductionFactoring
::
CanApply
(
const
std
::
string
&
block_name
,
ir
::
IRSchedule
*
ir_schedule
)
const
{
ir
::
Expr
block_expr
=
ir_schedule
->
GetBlock
(
block_name
);
ir
::
ScheduleBlockRealize
*
block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK_NOTNULL
(
block_realize
);
ir
::
ScheduleBlock
*
sch_block
=
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
CHECK_NOTNULL
(
sch_block
);
AnalyzeScheduleBlockReadWriteBuffer
(
sch_block
);
// 1. The block must have write buffer
if
(
sch_block
->
write_buffers
.
empty
())
{
return
false
;
}
// 2. The block must have at least one reduce axis
const
std
::
vector
<
ir
::
Var
>&
iter_vars
=
sch_block
->
iter_vars
;
bool
find_reduce_axis
=
false
;
for
(
int
i
=
0
;
i
<
iter_vars
.
size
();
++
i
)
{
if
(
iter_vars
[
i
]
->
is_reduce_axis
)
{
find_reduce_axis
=
true
;
break
;
}
}
if
(
!
find_reduce_axis
)
{
return
false
;
}
// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std
::
vector
<
ir
::
Expr
>
loops
=
ir_schedule
->
GetLoops
(
block_name
);
for
(
const
ir
::
Expr
&
loop
:
loops
)
{
const
ir
::
Expr
&
body
=
loop
.
As
<
ir
::
For
>
()
->
body
;
if
(
body
.
As
<
ir
::
Block
>
())
{
if
(
body
.
As
<
ir
::
Block
>
()
->
stmts
.
size
()
==
1
)
{
if
(
body
.
As
<
ir
::
Block
>
()
->
stmts
[
0
].
As
<
ir
::
For
>
()
==
nullptr
&&
body
.
As
<
ir
::
Block
>
()
->
stmts
[
0
].
As
<
ir
::
ScheduleBlockRealize
>
()
==
nullptr
)
{
return
false
;
}
}
else
if
(
body
.
As
<
ir
::
Block
>
()
->
stmts
.
size
()
==
2
)
{
if
(
body
.
As
<
ir
::
Block
>
()
->
stmts
[
0
].
As
<
ir
::
ScheduleBlockRealize
>
()
==
nullptr
||
!
ir
::
IsReduceInitTensorName
(
GetBlockName
(
body
.
As
<
ir
::
Block
>
()
->
stmts
[
0
])))
{
return
false
;
}
if
(
body
.
As
<
ir
::
Block
>
()
->
stmts
[
1
].
As
<
ir
::
For
>
()
==
nullptr
&&
body
.
As
<
ir
::
Block
>
()
->
stmts
[
1
].
As
<
ir
::
ScheduleBlockRealize
>
()
==
nullptr
)
{
return
false
;
}
}
else
{
return
false
;
}
}
else
if
(
body
.
As
<
ir
::
For
>
()
||
body
.
As
<
ir
::
ScheduleBlockRealize
>
())
{
continue
;
}
else
{
return
false
;
}
}
return
true
;
}
RuleApplyType
ReductionFactoring
::
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
{
return
this
->
CanApply
(
block_name
,
&
(
state
->
ir_schedule
))
?
RuleApplyType
::
kApply
:
RuleApplyType
::
kCannotApply
;
}
std
::
vector
<
SearchState
>
ReductionFactoring
::
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
{
SearchState
new_state
=
state
.
Copy
();
Apply
(
block_name
,
&
(
new_state
->
ir_schedule
));
return
{
new_state
};
}
void
ReductionFactoring
::
Apply
(
const
std
::
string
&
block_name
,
ir
::
IRSchedule
*
ir_schedule
)
{
ir
::
Expr
block
=
ir_schedule
->
GetBlock
(
block_name
);
std
::
vector
<
ir
::
Expr
>
all_loops
=
ir_schedule
->
GetLoops
(
block_name
);
std
::
vector
<
ir
::
Expr
>
new_loop_order
;
size_t
num_spatial_loops
=
0
;
size_t
num_reduction_loops
=
0
;
// 1. Add all spatial loops
std
::
unordered_set
<
std
::
string
>
reduce_loop_var_names
=
GetReduceLoopVarNames
(
block
);
for
(
const
ir
::
Expr
&
expr
:
all_loops
)
{
if
(
reduce_loop_var_names
.
count
(
expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
)
==
0
)
{
new_loop_order
.
push_back
(
expr
);
++
num_spatial_loops
;
}
}
// 2. Add all reduction loops
for
(
const
ir
::
Expr
&
expr
:
all_loops
)
{
if
(
reduce_loop_var_names
.
count
(
expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
)
>
0
)
{
new_loop_order
.
push_back
(
expr
);
++
num_reduction_loops
;
}
}
if
(
num_reduction_loops
==
0
)
{
return
;
}
// 3. Reorder if new_loop_order differs from the original order
CHECK_EQ
(
all_loops
.
size
(),
new_loop_order
.
size
());
for
(
int
i
=
0
;
i
<
all_loops
.
size
();
++
i
)
{
if
(
all_loops
[
i
].
As
<
ir
::
For
>
()
->
loop_var
->
name
!=
new_loop_order
[
i
].
As
<
ir
::
For
>
()
->
loop_var
->
name
)
{
ir_schedule
->
Reorder
(
new_loop_order
);
break
;
}
}
// 4. Fuse all reduction loops
ir
::
Expr
fused_reduce_loop
;
VLOG
(
6
)
<<
"before Fuse: "
<<
ir_schedule
->
GetModule
().
GetExprs
()[
0
];
if
(
num_reduction_loops
>
1
)
{
std
::
vector
<
int
>
reduction_loop_indices
;
for
(
int
i
=
num_spatial_loops
;
i
<
all_loops
.
size
();
++
i
)
{
reduction_loop_indices
.
push_back
(
i
);
}
CHECK_EQ
(
reduction_loop_indices
.
size
(),
num_reduction_loops
);
fused_reduce_loop
=
ir_schedule
->
Fuse
(
block_name
,
reduction_loop_indices
);
}
else
{
all_loops
=
ir_schedule
->
GetLoops
(
block_name
);
fused_reduce_loop
=
all_loops
.
back
();
}
// 5. Split the reduction loop into 2 part
VLOG
(
6
)
<<
"before Split: "
<<
ir_schedule
->
GetModule
().
GetExprs
()[
0
];
int
factor
=
1
;
int
max_factor
=
1024
;
int
extent
=
ir
::
GetLoopExtent
(
fused_reduce_loop
);
for
(
int
i
=
max_factor
;
i
>=
1
;
--
i
)
{
if
(
extent
%
i
==
0
)
{
factor
=
i
;
break
;
}
}
std
::
vector
<
cinn
::
ir
::
Expr
>
splited_reduction_loops
=
ir_schedule
->
Split
(
fused_reduce_loop
,
{
factor
,
-
1
});
// 6. Apply FactorizeReduction
VLOG
(
6
)
<<
"before FactorizeReduction: "
<<
ir_schedule
->
GetModule
().
GetExprs
()[
0
];
ir_schedule
->
FactorizeReduction
(
splited_reduction_loops
[
0
],
num_spatial_loops
);
VLOG
(
6
)
<<
"after FactorizeReduction: "
<<
ir_schedule
->
GetModule
().
GetExprs
()[
0
];
// 7. Loop fusion and cross thread reduction
std
::
vector
<
ir
::
Expr
>
rb_loops
=
ir_schedule
->
GetLoops
(
block_name
);
ir
::
Expr
rf_block
=
ir_schedule
->
GetBlock
(
block_name
+
"_rf"
);
ir_schedule
->
SimpleComputeAt
(
rf_block
,
rb_loops
.
back
());
rb_loops
=
ir_schedule
->
GetLoops
(
block_name
);
ir
::
Expr
rf_init_block
=
ir_schedule
->
GetBlock
(
block_name
+
"_rf__reduce_init"
);
ir_schedule
->
SimpleComputeAt
(
rf_init_block
,
rb_loops
.
back
());
if
(
*
target_
==
common
::
DefaultNVGPUTarget
())
{
rb_loops
=
ir_schedule
->
GetLoops
(
block_name
);
rf_block
=
ir_schedule
->
GetBlock
(
block_name
+
"_rf"
);
ir_schedule
->
Bind
(
rb_loops
.
back
(),
"threadIdx.x"
);
ir_schedule
->
SetBuffer
(
rf_block
,
"shared"
);
}
VLOG
(
6
)
<<
"Loop fusion and cross thread reduction: "
<<
ir_schedule
->
GetModule
().
GetExprs
()[
0
];
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Auto-gen rule that rewrites a reduction into a factorized form
// (an intermediate "_rf" tensor) so that the reduction can be
// parallelized; applied per schedule block via ApplyOnBlock/Apply.
class ReductionFactoring : public AutoGenRule {
 public:
  explicit ReductionFactoring(const common::Target& target)
      : AutoGenRule(target) {}
  ~ReductionFactoring() = default;

  // In the future, we will no longer use this interface.
  // Always reports kCannotApply: this rule is only driven through the
  // block-based interface below.
  RuleApplyType Init(ir::IRSchedule* init_schedule) override {
    return RuleApplyType::kCannotApply;
  }

  // In the future, we will no longer use this interface.
  // Deprecated index-based entry point; aborts if called.
  void Apply(int index) override {
    LOG(FATAL) << "This is a deprecated interface, please do not use it.";
    return;
  }

  // Returns whether this rule can be applied to the block named
  // `block_name` inside `state`'s schedule.
  RuleApplyType AnalyseApplyType(SearchState state,
                                 const std::string& block_name) const override;

  std::string GetRuleName() const override { return "ReductionFactoring"; }

  // Copies `state`, applies the transformation to the named block in the
  // copy, and returns the resulting state(s).
  std::vector<SearchState> ApplyOnBlock(SearchState state,
                                        const std::string& block_name) override;

  // Performs the actual reorder/fuse/split/FactorizeReduction sequence on
  // `ir_schedule` for the block named `block_name`.
  void Apply(const std::string& block_name, ir::IRSchedule* ir_schedule);

 private:
  // Checks preconditions (e.g. whether the block is a suitable reduction)
  // shared by AnalyseApplyType and ApplyOnBlock.
  bool CanApply(const std::string& block_name,
                ir::IRSchedule* ir_schedule) const;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/concrete_program_builder.h"
PD_DECLARE_bool
(
cinn_new_group_scheduler
);
namespace
cinn
{
namespace
auto_schedule
{
// Test fixture for the ReductionFactoring rule. Builds a reduce op with a
// given shape/reduce_dim, applies the rule to one block, and compares the
// printed IR against an expected string.
class TestReductionFactoring : public TestAutoGenRuleBase {
 public:
  std::vector<std::string> default_input_names = {"X"};
  std::vector<std::string> default_output_names = {"out"};

  // Runs one end-to-end check:
  //   shape       - input tensor shape of "X"
  //   reduce_dim  - axes to reduce over
  //   block_name  - schedule block the rule is applied to
  //   expected_ir - exact printed IR expected after the rule runs
  void TestApplyOnReduce(const std::vector<int>& shape,
                         const std::vector<int>& reduce_dim,
                         const std::string& block_name,
                         const std::string& expected_ir) {
    Initialize(common::DefaultNVGPUTarget());
    // In order to forcibly use the most basic Compute of reduction
    FLAGS_cinn_new_group_scheduler = 1;
    auto test_program = tests::ReduceBuilder().Build(
        {{"X", shape}}, {{"reduce_dim", reduce_dim}});
    // construct input parameter
    ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
    SearchState state(ir_schedule, 0, {});
    std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
    ASSERT_EQ(func_bodys.size(), 1UL);
    VLOG(6) << "Original Expr:\n" << func_bodys[0];
    // apply
    ReductionFactoring reduction_factoring(target_);
    ASSERT_EQ(reduction_factoring.AnalyseApplyType(state, block_name),
              RuleApplyType::kApply);
    auto result = reduction_factoring.ApplyOnBlock(state, block_name)[0];
    std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
    EXPECT_EQ(exprs.size(), 1UL);
    std::stringstream ir;
    ir << exprs[0];
    VLOG(6) << "ReductionFactoring applied Expr: " << exprs[0];
    // check
    // NOTE(review): blocks are queried from the pre-rule `ir_schedule`,
    // while the printed IR comes from the rule's result state — presumably
    // intentional (SearchState holds its own copy); verify if this check
    // ever fails.
    const std::vector<ir::Expr>& blocks = ir_schedule.GetAllBlocks();
    CHECK_EQ(blocks.size(), 2UL);
    CHECK_EQ(ir.str(), expected_ir);
  }
};
// AnalyseApplyType must reject a block that performs no reduction:
// an elementwise_add program has only spatial loops, so the rule
// reports kCannotApply.
TEST_F(TestReductionFactoring, AnalyseApplyType) {
  Context::Global().ResetNameId();
  Initialize(common::DefaultNVGPUTarget());
  auto test_program = tests::OpBuilder("elementwise_add")
                          .Build({{"X", {4, 5}}, {"Y", {4, 5}}});
  ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
  VLOG(6) << "Original Expr:\n" << ir_schedule.GetModule().GetExprs()[0];
  SearchState state(ir_schedule, 0, {});
  ReductionFactoring reduction_factoring(target_);
  // "var_1" is the output block of the elementwise add — no reduce axes.
  EXPECT_EQ(reduction_factoring.AnalyseApplyType(state, "var_1"),
            RuleApplyType::kCannotApply);
}
#ifdef CINN_WITH_CUDA
// Reduce a [32, 64] tensor over dim 1. Expected IR: reduce loop of extent 64
// becomes the threadIdx.x-bound "_rf" loop with an inner serial loop of
// extent 1 (64 divides the max factor search exactly).
// NOTE(review): the expected-IR literal below was recovered from a rendering
// that stripped leading whitespace inside the raw string — confirm the
// indentation against the printer's actual output.
TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) {
  Context::Global().ResetNameId();
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f
}
{
serial for (reduce_k_0_1, 0, 1)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1)
var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64}, {1}, "var_0", expected_ir);
}
// Reduce a [32, 64, 128] tensor over dims {1, 2}. The two reduce loops
// (64 * 128 = 8192 iterations) are fused, then split as 1024 x 8: the outer
// 1024 bound to threadIdx.x, the inner 8 serial.
// NOTE(review): expected-IR literal recovered from a rendering that stripped
// leading whitespace inside the raw string — confirm indentation.
TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) {
  Context::Global().ResetNameId();
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64, 128}, {1, 2}, "var_0", expected_ir);
}
// Reduce a [32, 64, 64, 64] tensor over dims {1, 2, 3}. The three reduce
// loops (64^3 = 262144 iterations) are fused, then split as 1024 x 256:
// outer 1024 bound to threadIdx.x, inner 256 serial.
// NOTE(review): expected-IR literal recovered from a rendering that stripped
// leading whitespace inside the raw string — confirm indentation.
TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) {
  Context::Global().ResetNameId();
  std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f
}
{
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256)
{
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused])
}
}
}
}
}
}
}
})";
  TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir);
}
#endif
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
View file @
01a10755
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir.h"
...
@@ -52,9 +53,9 @@ TEST(SkipRule, Basic) {
...
@@ -52,9 +53,9 @@ TEST(SkipRule, Basic) {
ir
::
Tensor
C
=
Compute
(
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
ast_gen_ius
::
TensorGroup
tensor_group
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
"TestSkipRule_Basic"
,
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
lang
::
LowerToAstVec
(
"TestSkipRule_Basic"
,
{
C
},
&
tensor_group
,
target
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
...
@@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
...
@@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
ir
::
Tensor
C
=
Compute
(
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
ast_gen_ius
::
TensorGroup
tensor_group
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
"TestSkipRule_Basic"
,
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
lang
::
LowerToAstVec
(
"TestSkipRule_Basic"
,
{
C
},
&
tensor_group
,
target
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
...
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
View file @
01a10755
...
@@ -61,12 +61,14 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
...
@@ -61,12 +61,14 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
"inferdtype"
);
"inferdtype"
);
auto
&
shape_dict
=
graph
->
GetMutableAttrs
<
auto
&
shape_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
hlir
::
framework
::
OpLowerer
op_lowerer
(
dtype_dict
,
shape_dict
,
target_
);
auto
op_lowerer
=
hlir
::
framework
::
CreateOpLowerer
(
dtype_dict
,
shape_dict
,
target_
);
lowered_funcs_
=
lowered_funcs_
=
op_lowerer
.
Lower
(
graph
->
fusion_groups
.
front
(),
op_lowerer
.
Lower
(
graph
->
fusion_groups
.
front
(),
/*apply_op_schedule = */
apply_manual_schedule
,
/*apply_op_schedule = */
apply_manual_schedule
,
/*apply_group_schedule = */
apply_manual_schedule
);
/*apply_group_schedule = */
apply_manual_schedule
,
/*apply_pass = */
apply_manual_schedule
);
CHECK
(
!
lowered_funcs_
.
empty
())
<<
"lowered_funcs_ is empty"
;
CHECK
(
!
lowered_funcs_
.
empty
())
<<
"lowered_funcs_ is empty"
;
std
::
vector
<
Expr
>
bodys
;
std
::
vector
<
Expr
>
bodys
;
...
...
paddle/cinn/auto_schedule/search_space/search_space.cc
View file @
01a10755
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool
(
auto_schedule_use_cost_model
);
PD_
DECLARE_bool
(
auto_schedule_use_cost_model
);
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
...
paddle/cinn/auto_schedule/search_space/search_state.cc
View file @
01a10755
...
@@ -20,9 +20,9 @@
...
@@ -20,9 +20,9 @@
#include <vector>
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/string.h"
...
@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
...
@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
// compare exprs size firstly
// compare exprs size firstly
if
(
lhs_exprs
.
size
()
!=
rhs_exprs
.
size
())
return
false
;
if
(
lhs_exprs
.
size
()
!=
rhs_exprs
.
size
())
return
false
;
// compare every expr one by one with ir::IrEqualVisitor
// compare every expr one by one with ir::
ir_utils::
IrEqualVisitor
for
(
int
i
=
0
;
i
<
lhs_exprs
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
lhs_exprs
.
size
();
++
i
)
{
ir
::
IrEqualVisitor
compartor
(
if
(
!
ir
::
ir_utils
::
IRCompare
(
lhs_exprs
[
i
],
rhs_exprs
[
i
],
true
))
/*allow_name_suffix_diff=*/
true
);
// ignore suffix difference in name
return
false
;
if
(
!
compartor
.
Compare
(
lhs_exprs
[
i
],
rhs_exprs
[
i
]))
return
false
;
}
}
return
true
;
return
true
;
}
}
...
...
paddle/cinn/auto_schedule/search_space/search_state.h
View file @
01a10755
...
@@ -20,9 +20,9 @@
...
@@ -20,9 +20,9 @@
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
@@ -70,8 +70,8 @@ struct SearchStateHash {
...
@@ -70,8 +70,8 @@ struct SearchStateHash {
size_t
operator
()(
const
SearchState
&
s
)
const
;
size_t
operator
()(
const
SearchState
&
s
)
const
;
};
};
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare
their AST
// SearchStateHash equal functor, use ir::
ir_utils::
IrEqualVisitor to compare
// struct and fields
//
their AST
struct and fields
struct
SearchStateEqual
{
struct
SearchStateEqual
{
bool
operator
()(
const
SearchState
&
lhs
,
const
SearchState
&
rhs
)
const
;
bool
operator
()(
const
SearchState
&
lhs
,
const
SearchState
&
rhs
)
const
;
};
};
...
...
paddle/cinn/auto_schedule/search_space/search_state_test.cc
View file @
01a10755
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/context.h"
...
@@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) {
...
@@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) {
ir
::
Tensor
C
=
lang
::
Compute
(
ir
::
Tensor
C
=
lang
::
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
+
B
(
i
,
j
);
},
"C"
);
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
+
B
(
i
,
j
);
},
"C"
);
ast_gen_ius
::
TensorGroup
const_group_1
({
A
,
B
});
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
auto
a_plus_const_funcs_1
=
lang
::
LowerVec
(
"A_plus_const"
,
auto
a_plus_const_funcs_1
=
poly
::
CreateStages
({
A
,
B
}),
lang
::
LowerToAstVec
(
"A_plus_const"
,
{
A
,
B
},
&
const_group_1
,
target
);
{
A
,
B
},
{},
{},
nullptr
,
target
,
true
);
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
auto
a_plus_const_funcs_2
=
lang
::
LowerVec
(
"A_plus_const"
,
ast_gen_ius
::
TensorGroup
const_group_2
({
A
,
B
});
poly
::
CreateStages
({
A
,
B
}),
auto
a_plus_const_funcs_2
=
{
A
,
B
},
lang
::
LowerToAstVec
(
"A_plus_const"
,
{
A
,
B
},
&
const_group_2
,
target
);
{},
{},
nullptr
,
target
,
true
);
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
cinn
::
common
::
Context
::
Global
().
ResetNameId
();
auto
a_plus_b_funcs
=
lang
::
LowerVec
(
"A_plus_B"
,
ast_gen_ius
::
TensorGroup
plus_group
({
A
,
C
});
poly
::
CreateStages
({
A
,
C
}),
auto
a_plus_b_funcs
=
{
A
,
C
},
lang
::
LowerToAstVec
(
"A_plus_B"
,
{
A
,
C
},
&
plus_group
,
target
);
{},
{},
nullptr
,
target
,
true
);
std
::
string
a_plus_const_funcs_1_str
=
R"ROC(function A_plus_const (_A, _B)
std
::
string
a_plus_const_funcs_1_str
=
R"ROC(function A_plus_const (_A, _B)
{
{
...
...
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
View file @
01a10755
...
@@ -4,5 +4,6 @@ core_gather_headers()
...
@@ -4,5 +4,6 @@ core_gather_headers()
gather_srcs
(
cinnapi_src SRCS evolutionary_search.cc
)
gather_srcs
(
cinnapi_src SRCS evolutionary_search.cc
)
cinn_cc_test
(
test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
# TODO(zhhsplendid): enable this test again
cinncore test_program_builder
)
#cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
# cinncore test_program_builder)
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc
View file @
01a10755
...
@@ -36,7 +36,7 @@
...
@@ -36,7 +36,7 @@
#include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool
(
auto_schedule_use_cost_model
);
PD_
DECLARE_bool
(
auto_schedule_use_cost_model
);
namespace
cinn
{
namespace
cinn
{
namespace
auto_schedule
{
namespace
auto_schedule
{
...
@@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
...
@@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
InitialTaskRegistry
*
task_registry
=
InitialTaskRegistry
::
Global
();
InitialTaskRegistry
*
task_registry
=
InitialTaskRegistry
::
Global
();
for
(
auto
&&
record
:
records
)
{
for
(
auto
&&
record
:
records
)
{
ir
::
IRSchedule
ir_sch
(
ir
::
IRSchedule
ir_sch
(
optim
::
IRCopy
(
task_registry
->
Get
(
task_key
)
->
module_expr
),
ir
::
ir_utils
::
IRCopy
(
task_registry
->
Get
(
task_key
)
->
module_expr
),
utils
::
ForkRandomState
(
&
rand_seed_
));
utils
::
ForkRandomState
(
&
rand_seed_
));
ir
::
ScheduleDesc
::
ReplayWithProto
(
record
.
trace
,
&
ir_sch
);
ir
::
ScheduleDesc
::
ReplayWithProto
(
record
.
trace
,
&
ir_sch
);
results
.
emplace_back
(
SearchState
(
std
::
move
(
ir_sch
),
record
.
predicted_cost
));
results
.
emplace_back
(
SearchState
(
std
::
move
(
ir_sch
),
record
.
predicted_cost
));
...
@@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
...
@@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
for
(
size_t
i
=
0
;
i
<
father_exprs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
father_exprs
.
size
();
++
i
)
{
if
(
utils
::
SampleUniformInt
(
0
,
2
,
&
rand_seed_
)
==
0
)
{
if
(
utils
::
SampleUniformInt
(
0
,
2
,
&
rand_seed_
)
==
0
)
{
cross_over_exprs
.
push_back
(
optim
::
IRCopy
(
father_exprs
[
i
]));
cross_over_exprs
.
push_back
(
ir
::
ir_utils
::
IRCopy
(
father_exprs
[
i
]));
}
else
{
}
else
{
cross_over_exprs
.
push_back
(
optim
::
IRCopy
(
mother_exprs
[
i
]));
cross_over_exprs
.
push_back
(
ir
::
ir_utils
::
IRCopy
(
mother_exprs
[
i
]));
}
}
}
}
auto
res
=
SearchState
(
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
cross_over_exprs
),
auto
res
=
SearchState
(
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
cross_over_exprs
),
...
@@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate(
...
@@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate(
// ir_schedule
// ir_schedule
const
auto
&
task_key
=
tune_task_
.
serialized_key
;
const
auto
&
task_key
=
tune_task_
.
serialized_key
;
InitialTaskRegistry
*
task_registry
=
InitialTaskRegistry
::
Global
();
InitialTaskRegistry
*
task_registry
=
InitialTaskRegistry
::
Global
();
ir
::
IRSchedule
new_
ir_sch
(
ir
::
IRSchedule
p
ir_sch
(
optim
::
IRCopy
(
task_registry
->
Get
(
task_key
)
->
module_expr
),
ir
::
ir_utils
::
IRCopy
(
task_registry
->
Get
(
task_key
)
->
module_expr
),
utils
::
ForkRandomState
(
rand_seed
));
utils
::
ForkRandomState
(
rand_seed
));
new_trace
.
Replay
(
&
new_
ir_sch
,
true
);
new_trace
.
Replay
(
&
p
ir_sch
,
true
);
ApplyPostScheduleRules
(
&
new_
ir_sch
,
post_schedule_rules_
);
ApplyPostScheduleRules
(
&
p
ir_sch
,
post_schedule_rules_
);
auto
res
=
SearchState
(
std
::
move
(
new_
ir_sch
));
auto
res
=
SearchState
(
std
::
move
(
p
ir_sch
));
VLOG
(
5
)
<<
JoinStatesDebugString
(
VLOG
(
5
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Mutate"
,
{
state
,
res
},
/*verbose=*/
VLOG_IS_ON
(
6
));
"EvolutionarySearch::Mutate"
,
{
state
,
res
},
/*verbose=*/
VLOG_IS_ON
(
6
));
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment