Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
992bec46
Commit
992bec46
authored
Oct 08, 2023
by
“yuguo”
Browse files
2.5
parent
0259837d
Changes
357
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2265 additions
and
0 deletions
+2265
-0
paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
+80
-0
paddle/cinn/auto_schedule/search_space/search_space.cc
paddle/cinn/auto_schedule/search_space/search_space.cc
+348
-0
paddle/cinn/auto_schedule/search_space/search_space.h
paddle/cinn/auto_schedule/search_space/search_space.h
+115
-0
paddle/cinn/auto_schedule/search_space/search_space_test.cc
paddle/cinn/auto_schedule/search_space/search_space_test.cc
+21
-0
paddle/cinn/auto_schedule/search_space/search_state.cc
paddle/cinn/auto_schedule/search_space/search_state.cc
+164
-0
paddle/cinn/auto_schedule/search_space/search_state.h
paddle/cinn/auto_schedule/search_space/search_state.h
+91
-0
paddle/cinn/auto_schedule/search_space/search_state_test.cc
paddle/cinn/auto_schedule/search_space/search_state_test.cc
+161
-0
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
+8
-0
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc
...cinn/auto_schedule/search_strategy/evolutionary_search.cc
+372
-0
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h
.../cinn/auto_schedule/search_strategy/evolutionary_search.h
+169
-0
paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc
...auto_schedule/search_strategy/evolutionary_search_test.cc
+214
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt
.../auto_schedule/search_strategy/mutate_rule/CMakeLists.txt
+5
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc
.../auto_schedule/search_strategy/mutate_rule/mutate_rule.cc
+32
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h
...n/auto_schedule/search_strategy/mutate_rule/mutate_rule.h
+51
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc
..._schedule/search_strategy/mutate_rule/mutate_tile_size.cc
+156
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h
...o_schedule/search_strategy/mutate_rule/mutate_tile_size.h
+36
-0
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc
...dule/search_strategy/mutate_rule/mutate_tile_size_test.cc
+140
-0
paddle/cinn/auto_schedule/task/CMakeLists.txt
paddle/cinn/auto_schedule/task/CMakeLists.txt
+8
-0
paddle/cinn/auto_schedule/task/task_creator.cc
paddle/cinn/auto_schedule/task/task_creator.cc
+58
-0
paddle/cinn/auto_schedule/task/task_creator.h
paddle/cinn/auto_schedule/task/task_creator.h
+36
-0
No files found.
Too many changes to show.
To preserve performance only
357 of 357+
files are displayed.
Plain diff
Email patch
paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Target shared by all tests in this file: the default NVIDIA GPU target when
// CINN is built with CUDA support, otherwise the default host target.
Target target =
#ifdef CINN_WITH_CUDA
    common::DefaultNVGPUTarget();
#else
    common::DefaultHostTarget();
#endif
std
::
vector
<
AutoGenRule
*>
GenerateTestRules
()
{
return
{
new
AutoUnroll
(
target
),
new
SkipRule
(
target
)};
}
// RuleSampler::Make must return a sampler whose Name() matches the requested
// strategy string, for both the "traversal" and "probabilistic" strategies.
TEST(RuleSampler, Make) {
  auto rules = GenerateTestRules();

  auto traversal_sampler = RuleSampler::Make(rules, true, "traversal");
  ASSERT_STREQ(traversal_sampler->Name(), "traversal");

  auto probabilistic_sampler = RuleSampler::Make(rules, true, "probabilistic");
  ASSERT_STREQ(probabilistic_sampler->Name(), "probabilistic");
}
// Checks the order in which the "traversal" sampler hands out rules.
TEST(TraversalRuleSampler, NextRule) {
  auto rules = GenerateTestRules();

  // Second argument `true`: consecutive calls walk through the rule list in
  // order, and Reset() restarts the walk from the first rule.
  auto sampler = RuleSampler::Make(rules, true, "traversal");
  ASSERT_EQ("AutoUnroll", sampler->NextRule()->GetRuleName());
  ASSERT_EQ("SkipRule", sampler->NextRule()->GetRuleName());
  sampler->Reset();
  ASSERT_EQ("AutoUnroll", sampler->NextRule()->GetRuleName());

  // Second argument `false`: the sampler keeps returning the first rule.
  sampler = RuleSampler::Make(rules, false, "traversal");
  ASSERT_EQ("AutoUnroll", sampler->NextRule()->GetRuleName());
  ASSERT_EQ("AutoUnroll", sampler->NextRule()->GetRuleName());
}
// Exercises the "probabilistic" sampler built with seed 0 and weights {4, 1}.
TEST(ProbabilisticRuleSampler, NextRule) {
  auto rules = GenerateTestRules();

  // Second argument `false`: 20 consecutive draws must all yield a rule that
  // can be dereferenced; the sampled names are only logged.
  auto sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
  for (int draw = 0; draw < 20; ++draw) {
    AutoGenRule* sampled = sampler->NextRule();
    VLOG(6) << "next rule name: " << sampled->GetRuleName();
  }

  // Second argument `true`: with two rules available, the third draw is
  // expected to return nullptr.
  sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
  sampler->NextRule();
  sampler->NextRule();
  ASSERT_EQ(nullptr, sampler->NextRule());
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_space.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_space.h"

#include <glog/logging.h>

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"
// Boolean flag declared here and defined elsewhere. When it is true,
// SearchSpace::GetScheduleMutate asks the cost model for a predicted cost of
// the mutated state.
DECLARE_bool(auto_schedule_use_cost_model);
namespace
cinn
{
namespace
auto_schedule
{
// Builds the search space for one tuning task.
//
// @param tune_task Task to be tuned. It is bound to the reference member
//   tune_task_, so the caller must keep it alive as long as this object.
// @param rand_seed Seed of the random engine; it is passed through
//   LinearRandomEngine::NormalizeState before being stored.
SearchSpace::SearchSpace(const TuneTask& tune_task,
                         utils::LinearRandomEngine::StateType rand_seed)
    : tune_task_(tune_task),
      rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
  const auto& target = tune_task_.target;
  // initialize a set of rules and they are commonly used by all states
  // TODO(zhhsplendid): pass correct output names to AutoInline
  // sketch_rules_.emplace_back(new AutoInline(target,
  // tune_task_.output_names));
  //
  // Each rule is created with std::make_unique so that it is already owned by
  // a unique_ptr when the vector grows; with `emplace_back(new T)` the object
  // would leak if the reallocation threw.
  //
  // SkipRule must stay the last entry: the pruned sketch strategies drop it
  // by iterating up to `sketch_rules_.end() - 1`.
  sketch_rules_.push_back(std::make_unique<MultiLevelTiling>(
      target, MultiLevelTiling::kConfigs.at(target.arch)));
  sketch_rules_.push_back(std::make_unique<AutoUnroll>(target));
  sketch_rules_.push_back(std::make_unique<SkipRule>(target));
}
// Mutates `state` once and returns the mutated state.
//
// The manual-schedule path is hard-wired off, so the mutation always comes
// from RandomScheduleMutate. When FLAGS_auto_schedule_use_cost_model is set,
// the predicted cost of the result is filled in by `cost_model`.
SearchState SearchSpace::GetScheduleMutate(const SearchState& state,
                                           const ExprCostModel& cost_model) {
  const bool has_manual_schedule = false;
  if (has_manual_schedule) {
    return ManualScheduleMutate(state);
  }

  SearchState mutated = RandomScheduleMutate(state);
  if (FLAGS_auto_schedule_use_cost_model) {
    mutated->predicted_cost = cost_model.Predict(
        mutated->ir_schedule.GetModule(), tune_task_.target);
  }

  // Note: the debug dump below prints the input `state`.
  VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate",
                                   {state},
                                   /*verbose=*/VLOG_IS_ON(5));
  return mutated;
}
// Placeholder for mutation by a manually written schedule: currently returns
// the input state unchanged.
SearchState SearchSpace::ManualScheduleMutate(const SearchState& state) {
  // TODO(zhhsplendid): Add manual schedule mutate
  return state;
}
// Applies one randomly chosen applicable rule to a state and returns it.
//
// Every rule in the state's applicable_rules is initialised on the state's
// IRSchedule. Rules that can be applied are laid out on an integer axis, each
// occupying NumberApplicable() slots; one slot is then sampled and the
// matching rule is applied with the offset of that slot inside its range.
// The applied rule is finally removed from applicable_rules.
SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
  SearchState ret(state);
  auto& rules = ret->applicable_rules;

  // 1. Find the rules which can be applied on this Expr.
  // 2. Build the distribution: key = first slot owned by a rule,
  //    value = index of that rule in `rules`.
  std::map<int, int> weight_to_rule_index;
  std::vector<RuleApplyType> apply_types(rules.size());
  int cur_weight = 0;
  for (size_t pos = 0; pos < rules.size(); ++pos) {
    AutoGenRule* rule = rules.at(pos);
    const RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
    VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "="
            << static_cast<int>(apply_type);
    apply_types[pos] = apply_type;
    if (apply_type == RuleApplyType::kCannotApply) {
      continue;
    }
    weight_to_rule_index[cur_weight] = static_cast<int>(pos);
    cur_weight += rule->NumberApplicable();
  }

  if (weight_to_rule_index.empty()) {
    // No applicable rule, return the input mod_expr
    VLOG(6) << "No applicable rule";
    return ret;
  }

  // 3. Sample a slot on the distribution.
  // NOTE(review): presumably SampleUniformInt draws from [0, cur_weight) —
  // confirm against its definition.
  const int sample_weighted_index =
      utils::SampleUniformInt(0, cur_weight, &rand_seed_);

  // The owner of the slot is the last entry whose key is <= the sample.
  auto owner = weight_to_rule_index.upper_bound(sample_weighted_index);
  --owner;

  const int sample_rule_index = owner->second;
  const int index_in_rule = sample_weighted_index - owner->first;
  CHECK_LT(sample_rule_index, ret->applicable_rules.size());
  AutoGenRule* sample_rule = rules.at(sample_rule_index);
  VLOG(7) << "Apply rule: " << sample_rule->GetRuleName()
          << " with index=" << index_in_rule;

  // 4. Apply the schedule change.
  sample_rule->Apply(index_in_rule);

  // 5. Remove the rule after applying it.
  if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
    rules.erase(rules.begin() + sample_rule_index);
  }
  return ret;
}
std
::
vector
<
SearchState
>
SearchSpace
::
InitSketchWithRandomStrategy
(
int
num
)
{
VLOG
(
5
)
<<
"SearchSpace::GetRandomInitialSketch with num="
<<
num
;
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
(
tune_task_
.
GetLoweredFuncBodyExprs
()),
utils
::
ForkRandomState
(
&
rand_seed_
));
std
::
vector
<
AutoGenRule
*>
init_rules
;
std
::
transform
(
sketch_rules_
.
begin
(),
sketch_rules_
.
end
(),
std
::
back_inserter
(
init_rules
),
[](
const
auto
&
rule
)
{
return
rule
.
get
();
});
std
::
vector
<
SearchState
>
result
;
while
(
result
.
size
()
<
num
)
{
SearchState
state
(
init_schedule
,
SearchState
::
NOT_INIT_COST
,
init_rules
);
for
(
int
i
=
0
;
i
<
init_sketch_random_depth_
;
++
i
)
{
VLOG
(
6
)
<<
"Generating random sketch with RandomScheduleMutate at depth: "
<<
i
;
state
=
RandomScheduleMutate
(
state
);
if
(
state
->
applicable_rules
.
empty
())
{
break
;
}
}
VLOG
(
5
)
<<
JoinStatesDebugString
(
"SearchSpace::GetRandomInitialSketch-New_Sketch"
,
{
state
},
/*verbose=*/
VLOG_IS_ON
(
6
));
result
.
emplace_back
(
std
::
move
(
state
));
}
return
result
;
}
std
::
vector
<
SearchState
>
SearchSpace
::
InitSketchWithRandomPrunedStrategy
()
{
VLOG
(
5
)
<<
"SearchSpace::InitSketchWithRandomPrunedStrategy"
;
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
(
tune_task_
.
GetLoweredFuncBodyExprs
()),
utils
::
ForkRandomState
(
&
rand_seed_
));
auto
all_blocks
=
init_schedule
.
GetAllBlocks
();
auto
block_sampler
=
BlockSampler
::
Make
(
all_blocks
,
true
,
"probabilistic"
,
utils
::
ForkRandomState
(
&
rand_seed_
));
std
::
vector
<
AutoGenRule
*>
init_rules
;
std
::
transform
(
sketch_rules_
.
begin
(),
sketch_rules_
.
end
()
-
1
,
std
::
back_inserter
(
init_rules
),
[](
const
auto
&
rule
)
{
return
rule
.
get
();
});
CHECK
(
init_rules
.
size
()
>
0
)
<<
"number of init rules cannot be 0"
;
SearchState
init_state
(
init_schedule
,
SearchState
::
NOT_INIT_COST
,
{});
std
::
vector
<
SearchState
>
states_buf1
{
init_state
},
states_buf2
;
std
::
vector
<
SearchState
>*
p_states_cur
=
&
states_buf1
;
std
::
vector
<
SearchState
>*
p_states_next
=
&
states_buf2
;
int
total_steps
=
0
,
steps
;
std
::
string
block_name
;
while
(
""
!=
(
block_name
=
block_sampler
->
NextBlock
())
&&
total_steps
<
init_sketch_random_depth_
)
{
steps
=
utils
::
SampleUniformInt
(
1
,
init_rules
.
size
()
+
1
,
&
rand_seed_
);
if
(
total_steps
+
steps
>
init_sketch_random_depth_
)
{
steps
=
init_sketch_random_depth_
-
total_steps
;
}
total_steps
+=
steps
;
p_states_next
->
clear
();
for
(
const
auto
&
state
:
*
p_states_cur
)
{
auto
rule_sampler
=
RuleSampler
::
Make
(
init_rules
,
true
,
"probabilistic"
,
utils
::
ForkRandomState
(
&
rand_seed_
));
auto
new_states
=
ApplySketchRule
(
state
,
block_name
,
rule_sampler
.
get
(),
steps
,
false
,
1
);
p_states_next
->
insert
(
p_states_next
->
end
(),
new_states
.
begin
(),
new_states
.
end
());
}
std
::
swap
(
p_states_cur
,
p_states_next
);
}
VLOG
(
5
)
<<
JoinStatesDebugString
(
"SearchSpace::InitSketchWithRandomPrunedStrategy"
,
*
p_states_cur
,
/*verbose=*/
VLOG_IS_ON
(
6
));
return
*
p_states_cur
;
}
std
::
vector
<
SearchState
>
SearchSpace
::
InitSketchWithRulePrunedStrategy
()
{
VLOG
(
5
)
<<
"SearchSpace::InitSketchWithRulePrunedStrategy"
;
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
(
tune_task_
.
GetLoweredFuncBodyExprs
()),
utils
::
ForkRandomState
(
&
rand_seed_
));
auto
all_blocks
=
init_schedule
.
GetAllBlocks
();
std
::
reverse
(
all_blocks
.
begin
(),
all_blocks
.
end
());
auto
block_sampler
=
BlockSampler
::
Make
(
all_blocks
,
true
,
"traversal"
);
std
::
vector
<
AutoGenRule
*>
init_rules
;
std
::
transform
(
sketch_rules_
.
begin
(),
sketch_rules_
.
end
()
-
1
,
std
::
back_inserter
(
init_rules
),
[](
const
auto
&
rule
)
{
return
rule
.
get
();
});
CHECK
(
init_rules
.
size
()
>
0
)
<<
"number of init rules cannot be 0"
;
SearchState
init_state
(
init_schedule
,
SearchState
::
NOT_INIT_COST
,
{});
std
::
vector
<
SearchState
>
states_buf1
{
init_state
},
states_buf2
;
std
::
vector
<
SearchState
>*
p_states_cur
=
&
states_buf1
;
std
::
vector
<
SearchState
>*
p_states_next
=
&
states_buf2
;
std
::
string
block_name
;
while
(
""
!=
(
block_name
=
block_sampler
->
NextBlock
()))
{
p_states_next
->
clear
();
for
(
const
auto
&
state
:
*
p_states_cur
)
{
auto
rule_sampler
=
RuleSampler
::
Make
(
init_rules
,
true
,
"traversal"
);
auto
new_states
=
ApplySketchRule
(
state
,
block_name
,
rule_sampler
.
get
(),
0
,
true
);
p_states_next
->
insert
(
p_states_next
->
end
(),
new_states
.
begin
(),
new_states
.
end
());
}
std
::
swap
(
p_states_cur
,
p_states_next
);
}
VLOG
(
5
)
<<
JoinStatesDebugString
(
"SearchSpace::InitSketchWithRulePrunedStrategy"
,
*
p_states_cur
,
/*verbose=*/
VLOG_IS_ON
(
6
));
return
*
p_states_cur
;
}
// Generates `num` sketches as the initial population of the search.
//
// @param num Number of sketches to generate. A negative value yields an empty
//   result.
// @param strategy One of:
//   - "random": delegates to InitSketchWithRandomStrategy(num);
//   - "rule_prune": repeatedly calls InitSketchWithRulePrunedStrategy();
//   - "random_prune": repeatedly calls InitSketchWithRandomPrunedStrategy().
//   Any other value aborts with LOG(FATAL) as soon as a batch is requested.
// @return The generated sketches. For the pruned strategies each batch is
//   consumed from back to front until `num` sketches have been collected.
std::vector<SearchState> SearchSpace::GenerateSketches(
    int num, const std::string& strategy) {
  VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;

  // A negative count would be converted to a huge unsigned value by the
  // size comparisons below and the generation loop would never finish.
  if (num < 0) {
    return {};
  }
  const size_t target_num = static_cast<size_t>(num);

  if (strategy == "random") {
    return InitSketchWithRandomStrategy(num);
  }

  std::vector<SearchState> result;
  // NOTE(review): if a strategy keeps returning an empty batch this loop does
  // not make progress — confirm that at least one sketch is always produced.
  while (result.size() < target_num) {
    std::vector<SearchState> sketchs;
    if (strategy == "rule_prune") {
      sketchs = InitSketchWithRulePrunedStrategy();
    } else if (strategy == "random_prune") {
      sketchs = InitSketchWithRandomPrunedStrategy();
    } else {
      LOG(FATAL) << "Unimplemented init sketch strategy";
    }

    // the more rules are applied, the greater the possibility of good results,
    // the more rules are applied, the more they are saved behind the queue,
    // so we give priority to the results in the rear
    for (auto iter = sketchs.rbegin(); iter != sketchs.rend(); ++iter) {
      result.push_back(*iter);
      if (result.size() == target_num) {
        break;
      }
    }
  }
  VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches",
                                   result,
                                   /*verbose=*/VLOG_IS_ON(5));
  return result;
}
// Collects the states reachable from `state` by applying sampled rules on the
// block named `block_name`.
//
// @param state Starting state; it is the first element of the working list.
// @param block_name Name of the block the rules are applied on.
// @param rule_sampler Source of rules; the loop ends when NextRule() returns
//   a null pointer.
// @param steps Maximum number of rules to sample; 0 means "no limit", i.e.
//   keep going until the sampler is exhausted.
// @param prune_by_rule If true, the source state is dropped when the rule
//   reports kApplyAndPruneOtherRules; otherwise it is dropped when a uniform
//   sample in [0, 1] is below `prune_probability`.
// @param prune_probability Threshold used by random pruning.
// @return All states remaining in the working list, old ones first and the
//   states produced by later steps towards the back.
std::vector<SearchState> SearchSpace::ApplySketchRule(
    const SearchState& state,
    const std::string& block_name,
    RuleSampler* rule_sampler,
    int steps,
    bool prune_by_rule,
    double prune_probability) {
  std::list<SearchState> layer{state};
  int step = 0;
  AutoGenRule* rule;
  // After determining a SearchState and a block, each rule has two
  // possibilities: apply and not apply. In all transfer spaces, select a rule
  // at each step, and collect all possible new states arrived by apply and not
  // apply. This forms a tree, and we can use rule pruning or random pruning to
  // reduce the number of sketches.
  VLOG(6) << "Collect the states of all transfers within steps: " << steps;
  // `step` is incremented on every evaluation of the condition, and the
  // sampler is only queried while the step budget allows it.
  while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
    VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
    // States produced during this step; they are appended to `layer` only
    // after the step, so the current rule is not applied on them.
    std::list<SearchState> new_states;
    // Counter that only appears in the log statement below.
    int id = 0;
    // The iterator is advanced inside the body because elements may be erased.
    for (std::list<SearchState>::iterator iter = layer.begin();
         iter != layer.end();) {
      // Some rules will reduce the number of blocks, such as AutoInline,
      // so we need to check whether the SearchState still has the block.
      if (!(*iter)->ir_schedule.HasBlock(block_name)) {
        ++iter;
        continue;
      }
      auto type = rule->AnalyseApplyType(*iter, block_name);
      VLOG(7) << "At SearchState " << ++id << ", apply type = "
              << static_cast<typename std::underlying_type<RuleApplyType>::type>(
                     type);
      // if cannot apply the rule, skip it
      if (type == RuleApplyType::kCannotApply) {
        ++iter;
        continue;
      }
      // if can apply the rule, apply it and determine whether to prune the
      // branch that do not apply
      std::vector<SearchState> tmp_states =
          rule->ApplyOnBlock(*iter, block_name);
      new_states.insert(
          new_states.end(), tmp_states.begin(), tmp_states.end());
      bool need_prune = false;
      if (prune_by_rule) {
        need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
      } else {
        // A random number is consumed here for every applicable state.
        need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) <
                      prune_probability);
      }
      if (need_prune) {
        // Drop the "rule not applied" branch; erase returns the next element.
        iter = layer.erase(iter);
      } else {
        ++iter;
      }
    }
    VLOG(7) << "apply on block: " << block_name << ", generate "
            << new_states.size() << " new states at step " << step;
    layer.splice(layer.end(), std::move(new_states));
  }
  // NOTE(review): if every state was pruned and no new state was produced,
  // `layer` is empty and this unsigned subtraction wraps around; it only
  // affects the log message.
  VLOG(6) << "apply on block: " << block_name << ", generate "
          << layer.size() - 1 << " more states at all";
  return std::vector<SearchState>(layer.begin(), layer.end());
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_space.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * This class is an abstraction of the transformations that can be applied to
 * ir::Expr during auto-tuning. The transformation can be:
 *
 * 1. Manual defined schedule
 * 2. Schedule generated by AutoGenRule
 *
 * TODO(zhhsplendid): de-duplication the generated ModuleExpr
 */
class SearchSpace {
 public:
  /**
   * @param tune_task The task to tune. Only a reference is stored, so the
   * task must outlive this object.
   * @param rand_seed Seed of the random engine used for sampling.
   */
  SearchSpace(const TuneTask& tune_task,
              utils::LinearRandomEngine::StateType rand_seed = -1);

  // The class exposes virtual member functions, so it can be used as a
  // polymorphic base; a virtual destructor makes deleting a derived object
  // through a SearchSpace pointer well-defined.
  virtual ~SearchSpace() = default;

  // Declaring the destructor suppresses the implicit move constructor;
  // default it explicitly so the class stays move-constructible.
  SearchSpace(SearchSpace&&) = default;

  // Sketch mutate, returns the mutated state; its predicted cost is filled
  // in by the cost model when that is enabled.
  virtual SearchState GetScheduleMutate(const SearchState& state,
                                        const ExprCostModel& cost_model);

  /**
   * \brief Generate sketch as initial population of evolutionary search.
   * @param num The number of sketches to generate.
   * @param strategy The strategy to generate sketches,
   * Current optional strategies are "rule_prune" or "random_prune" or
   * "random".
   * - "rule_prune": will use rules to prune and generate sketches as
   * efficiently as possible.
   * - "random_prune": will use the new interface ApplySketchRule() to simulate
   * the random generation of sketches, and supports the function of a rule
   * returning multiple SearchStates and random pruning by probability.
   * - "random": will randomly select a block and a rule to apply and repeat
   * this step several times, however, each rule can only be used on one
   * SearchState at most once.
   * @return Generated sketches.
   */
  virtual std::vector<SearchState> GenerateSketches(
      int num, const std::string& strategy);

 private:
  // TODO(zhhsplendid): mutate by manual schedule.
  SearchState ManualScheduleMutate(const SearchState& state);

  // mutate by sketch rules randomly
  SearchState RandomScheduleMutate(const SearchState& state);

  // Generate num sketches, each with several rounds of RandomScheduleMutate
  std::vector<SearchState> InitSketchWithRandomStrategy(int num);

  // Generate sketch pruned randomly as initial population of evolutionary
  // search
  std::vector<SearchState> InitSketchWithRandomPrunedStrategy();

  // Generate sketch pruned by rules as initial population of evolutionary
  // search
  std::vector<SearchState> InitSketchWithRulePrunedStrategy();

  /**
   * @brief Collect the new states that may be transferred to after applying
   * several rules on a block from a certain state.
   * @param state Starting point of state transition.
   * @param block_name Name of the block to apply the rules to.
   * @param rule_sampler Sampler that samples the new rule to apply on the
   * block.
   * @param steps Number of steps to apply the rule; 0 means no limit.
   * @param prune_by_rule If true, prune the state transition tree by rule,
   * otherwise prune randomly.
   * @param prune_probability Pruning probability of random pruning.
   * @return The collected states.
   */
  std::vector<SearchState> ApplySketchRule(const SearchState& state,
                                           const std::string& block_name,
                                           RuleSampler* rule_sampler,
                                           int steps,
                                           bool prune_by_rule,
                                           double prune_probability = 1);

 private:
  // The task being tuned (not owned).
  const TuneTask& tune_task_;
  // Upper bound on the number of mutation steps used when building a sketch.
  int init_sketch_random_depth_ = 6;
  // supported AutoGenRules, every task holds a set
  std::vector<std::unique_ptr<AutoGenRule>> sketch_rules_;
  // State of the random engine, advanced by every sampling call.
  utils::LinearRandomEngine::StateType rand_seed_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_space_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include <gtest/gtest.h>
namespace
cinn
{
namespace
auto_schedule
{}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_state.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
auto_schedule
{
// Creates a new shared _SearchState_ node and fills its three fields from the
// arguments. The schedule is taken by value and moved into the node.
SearchState::SearchState(ir::IRSchedule ir_sch,
                         float cost,
                         const std::vector<AutoGenRule*>& rules)
    : common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
  auto* node = get();
  node->predicted_cost = cost;
  node->applicable_rules = rules;
  node->ir_schedule = std::move(ir_sch);
}
// Returns a new SearchState built from a copy of this state's IRSchedule and
// its predicted cost. The applicable rules are NOT carried over: the new
// state starts with an empty rule list.
SearchState SearchState::Copy() const {
  const SearchState& self = *this;
  return SearchState(self->ir_schedule, self->predicted_cost, {});
}
std
::
string
_SearchState_
::
DebugString
()
const
{
const
auto
&
exprs
=
ir_schedule
.
GetModule
().
GetExprs
();
std
::
stringstream
module_stream
;
for
(
auto
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
module_stream
<<
"Expr "
<<
i
<<
" {
\n
"
<<
exprs
.
at
(
i
)
<<
"
\n
} // end Expr"
;
}
const
char
*
fmt_str
=
R"ROC(
ModuleExpr {
%s
} // end ModuleExpr
ScheduleDesc {
%s
} // end ScheduleDesc
predicted_cost: %f)ROC"
;
return
utils
::
StringFormat
(
fmt_str
,
module_stream
.
str
().
c_str
(),
ir_schedule
.
GetTraceDesc
().
DebugString
().
c_str
(),
predicted_cost
);
}
// Orders two states by predicted cost: `left` comes first when its predicted
// cost is strictly lower than that of `right`.
bool operator<(const SearchState& left, const SearchState& right) {
  const float left_cost = left->predicted_cost;
  const float right_cost = right->predicted_cost;
  return left_cost < right_cost;
}
// Visitor that walks an IR tree depth-first: for every node type listed by
// NODETY_FORALL it visits each defined entry of the node's expr_fields().
class DfsWithExprsFields : public ir::IRVisitorRequireReImpl<void> {
 protected:
// One Visit override per IR node type.
#define __m(t__)                               \
  void Visit(const ir::t__* x) override {      \
    for (auto* field : x->expr_fields()) {     \
      if (field->defined()) {                  \
        Visit(field);                          \
      }                                        \
    }                                          \
  }

  NODETY_FORALL(__m)
#undef __m

  // Dispatches on the node type of `expr` through the base visitor.
  void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); }
};
// Generates a hash of an AST by combining, in depth-first order, the node
// type code of every defined node. Only node types contribute to the hash;
// field values are not hashed, so trees with the same shape get the same key.
class IrNodesStructuralHash : public DfsWithExprsFields {
 public:
  // @param init_key Initial value the node type codes are combined into.
  explicit IrNodesStructuralHash(size_t init_key) : hash_key_(init_key) {}

  // Visits the tree rooted at `expr` and returns the accumulated hash.
  size_t operator()(const Expr* expr) {
    Visit(expr);
    return hash_key_;
  }

  // Mixes the node type of `expr` into the hash, then descends into its
  // fields. Undefined expressions are ignored.
  void Visit(const Expr* expr) override {
    if (!expr->defined()) return;
    auto type_code = static_cast<IrNodeTyUnderlyingType>(expr->node_type());
    hash_key_ = utils::HashCombine(hash_key_, type_code);
    DfsWithExprsFields::Visit(expr);
  }

 private:
  // Tensors are handled explicitly: each shape expression is visited, then
  // the buffer node.
  // NOTE(review): assumes x->buffer holds a _Buffer_ — confirm that As<>()
  // cannot return null here.
  void Visit(const ir::_Tensor_* x) override {
    for (auto& e : x->shape) {
      Visit(&e);
    }
    DfsWithExprsFields::Visit(x->buffer.As<ir::_Buffer_>());
  }

  using IrNodeTyUnderlyingType = std::underlying_type<ir::IrNodeTy>::type;
  size_t hash_key_;
};
// Hashes a state by chaining IrNodesStructuralHash over every expression of
// its module: the hash of one expression seeds the hash of the next.
size_t SearchStateHash::operator()(const SearchState& s) const {
  const auto& exprs = s->ir_schedule.GetModule().GetExprs();
  size_t seed = 0;
  for (const auto& expr : exprs) {
    IrNodesStructuralHash hasher(seed);
    seed = hasher(&expr);
  }
  return seed;
}
// Two states are equal when their modules hold the same number of
// expressions and each pair of expressions at the same position compares
// equal under ir::IrEqualVisitor with name-suffix differences allowed.
bool SearchStateEqual::operator()(const SearchState& lhs,
                                  const SearchState& rhs) const {
  const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
  const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();

  // Different expression counts can never match.
  const size_t expr_count = lhs_exprs.size();
  if (expr_count != rhs_exprs.size()) {
    return false;
  }

  // A fresh comparator is used for every pair.
  for (size_t pos = 0; pos < expr_count; ++pos) {
    ir::IrEqualVisitor comparator(/*allow_name_suffix_diff=*/true);
    if (!comparator.Compare(lhs_exprs[pos], rhs_exprs[pos])) {
      return false;
    }
  }
  return true;
}
std
::
string
JoinStatesDebugString
(
const
std
::
string
&
title
,
const
std
::
vector
<
SearchState
>&
states
,
bool
verbose
)
{
std
::
stringstream
ss
;
ss
<<
title
<<
" states size:"
<<
states
.
size
()
<<
"
\n
"
;
SearchStateHash
state_hasher
;
for
(
size_t
i
=
0
;
i
<
states
.
size
();
++
i
)
{
uint64_t
hash_key
=
state_hasher
(
states
[
i
]);
if
(
verbose
)
{
ss
<<
"
\t
State-"
<<
i
<<
" hash:"
<<
hash_key
<<
"
\t
content:------>"
<<
states
[
i
]
->
DebugString
()
<<
"
\n
<------"
;
}
else
{
ss
<<
"
\t
State-"
<<
i
<<
" hash:"
<<
hash_key
<<
"
\n
"
;
}
}
return
std
::
move
(
*
ss
.
rdbuf
()).
str
();
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_state.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <limits>
#include <vector>
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
auto_schedule
{
struct
_SearchState_
;
class
AutoGenRule
;
//! Shared (reference-counting) handle to a _SearchState_ node.
class SearchState : public common::Shared<_SearchState_> {
 public:
  // Marker value for a cost that has not been computed yet.
  static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();

  // Creates an empty handle.
  SearchState() = default;

  // Creates a handle to a new node holding the given schedule, predicted cost
  // and list of applicable rules.
  explicit SearchState(ir::IRSchedule ir_sch,
                       float cost = NOT_INIT_COST,
                       const std::vector<AutoGenRule*>& rules = {});

  // Returns a new state with a copy of this state's schedule and predicted
  // cost; the applicable rules are not copied.
  SearchState Copy() const;

  // Orders states by predicted cost.
  friend bool operator<(const SearchState& left, const SearchState& right);
};
//! Class to store intermediate states during search
struct _SearchState_ : public common::Object {
  // IRSchedule contains ir::ModuleExpr and trace scheduling process
  ir::IRSchedule ir_schedule;
  // Cost model predicted cost
  float predicted_cost;
  // The rules that can be applied to the IRSchedule at this state.
  // The pointers are not owned by the state.
  std::vector<AutoGenRule*> applicable_rules;

  // return detail string of content for debug;
  std::string DebugString() const;

  // Returns the type name of this object.
  const char* type_info() const override { return __type_info__; }

  // Type name of this object. Declared as pointer-to-const because a string
  // literal cannot be bound to a non-const `char*` in standard C++.
  static constexpr const char* __type_info__ = "auto_schedule_state";
};
/**
 * Hash functor for SearchState.
 *
 * Visits every AST node in DFS order and combines the hashes of their node
 * types. The implementation is defined outside this header.
 */
struct SearchStateHash {
  size_t operator()(const SearchState& s) const;
};
/**
 * Equality functor for SearchState, the counterpart of SearchStateHash.
 *
 * Uses ir::IrEqualVisitor to compare the AST structure and fields of the two
 * states. The implementation is defined outside this header.
 */
struct SearchStateEqual {
  bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
/**
 * Concatenate the debug strings of all states, with additional info.
 *
 * @param title   head of the result string
 * @param states  SearchState array to be debugged
 * @param verbose whether to enable more verbose debug info
 * @return the concatenated debug string
 */
std::string JoinStatesDebugString(const std::string& title,
                                  const std::vector<SearchState>& states,
                                  bool verbose = false);
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/search_state_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
namespace
cinn
{
namespace
auto_schedule
{
// Checks that SearchStateHash / SearchStateEqual treat two separately lowered
// but identical functions as the same state, and a different function as a
// different state.
TEST(TestSearchState, SearchStateHash_Equal) {
  Target target = common::DefaultHostTarget();

  ir::Expr M(32);
  ir::Expr N(32);

  // B = A + 2.0f, C = A + B
  lang::Placeholder<float> A("A", {M, N});
  ir::Tensor B = lang::Compute(
      {M, N}, [&](Var i, Var j) { return A(i, j) + ir::Expr(2.f); }, "B");
  ir::Tensor C = lang::Compute(
      {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");

  // The name-id counter is reset before each lowering.
  cinn::common::Context::Global().ResetNameId();
  auto const_funcs_first = lang::LowerVec("A_plus_const",
                                          poly::CreateStages({A, B}),
                                          {A, B},
                                          {},
                                          {},
                                          nullptr,
                                          target,
                                          true);

  cinn::common::Context::Global().ResetNameId();
  auto const_funcs_second = lang::LowerVec("A_plus_const",
                                           poly::CreateStages({A, B}),
                                           {A, B},
                                           {},
                                           {},
                                           nullptr,
                                           target,
                                           true);

  cinn::common::Context::Global().ResetNameId();
  auto sum_funcs = lang::LowerVec("A_plus_B",
                                  poly::CreateStages({A, C}),
                                  {A, C},
                                  {},
                                  {},
                                  nullptr,
                                  target,
                                  true);

  // NOTE(review): these literals are compared verbatim against the printed
  // IR below; confirm their whitespace matches the printer's output.
  std::string const_funcs_first_str = R"ROC(function A_plus_const (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
}
})ROC";

  std::string const_funcs_second_str = R"ROC(function A_plus_const (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
}
})ROC";

  std::string sum_funcs_str = R"ROC(function A_plus_B (_A, _C)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 2.00000000f)
}
}
}
serial for (i, 0, 32)
{
serial for (j, 0, 32)
{
ScheduleBlock(C)
{
i0_0, i1_0 = axis.bind(i, j)
C[i0_0, i1_0] = (A[i0_0, i1_0] + B[i0_0, i1_0])
}
}
}
}
}
})ROC";

  // Each lowering must yield exactly one function with the expected text.
  ASSERT_EQ(const_funcs_first.size(), 1);
  EXPECT_EQ(const_funcs_first_str,
            utils::GetStreamCnt(const_funcs_first.front()));
  ASSERT_EQ(const_funcs_second.size(), 1);
  EXPECT_EQ(const_funcs_second_str,
            utils::GetStreamCnt(const_funcs_second.front()));
  ASSERT_EQ(sum_funcs.size(), 1);
  EXPECT_EQ(sum_funcs_str, utils::GetStreamCnt(sum_funcs.front()));

  // Wrap each function body into a SearchState.
  SearchState const_state_first(
      ir::IRSchedule(ir::ModuleExpr({const_funcs_first.front()->body})));
  SearchState const_state_second(
      ir::IRSchedule(ir::ModuleExpr({const_funcs_second.front()->body})));
  SearchState sum_state(
      ir::IRSchedule(ir::ModuleExpr({sum_funcs.front()->body})));

  SearchStateHash hash_functor;
  SearchStateEqual equal_functor;

  // Identical IR: same hash and equal.
  ASSERT_EQ(hash_functor(const_state_first), hash_functor(const_state_second));
  ASSERT_TRUE(equal_functor(const_state_first, const_state_second));
  // Different IR: different hash and not equal.
  ASSERT_NE(hash_functor(const_state_first), hash_functor(sum_state));
  ASSERT_FALSE(equal_functor(const_state_first, sum_state));
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
0 → 100644
View file @
992bec46
# Descend into the mutate_rule subdirectory.
add_subdirectory(mutate_rule)

core_gather_headers()

# Add the evolutionary search implementation to the cinnapi_src source list.
gather_srcs(cinnapi_src SRCS evolutionary_search.cc)

# Unit test for the evolutionary search.
cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
             cinncore test_program_builder)
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h"
#include <glog/logging.h>
#include <algorithm>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/utils/multi_threading.h"
#include "paddle/cinn/utils/sized_multi_set.h"
#include "paddle/cinn/utils/string.h"
// Flag declared here and read in this file as
// FLAGS_auto_schedule_use_cost_model to decide whether predicted costs are
// computed with the cost model.
DECLARE_bool(auto_schedule_use_cost_model);
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Build an evolutionary searcher for one tune task.
 *
 * @param tune_task    task to tune; stored by reference
 * @param cost_model   cost model used for predictions; stored by reference
 * @param database     record database; stored as a raw pointer
 * @param rand_seed    initial random state, normalized through
 *                     LinearRandomEngine::NormalizeState
 * @param mutate_rules (rule name, weight) pairs; when empty, a single
 *                     "mutate_tile_size" rule with weight 1.0 is used
 */
EvolutionarySearch::EvolutionarySearch(
    const TuneTask& tune_task,
    const ExprCostModel& cost_model,
    Database* database,
    utils::LinearRandomEngine::StateType rand_seed,
    const std::vector<std::tuple<std::string, double>>& mutate_rules)
    : tune_task_(tune_task),
      cost_model_(cost_model),
      database_(database),
      rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)),
      mutators_(mutate_rules) {
  search_space_ = std::make_unique<SearchSpace>(
      tune_task, utils::ForkRandomState(&rand_seed_));

  // Fall back to the default mutate rule when the caller supplied none.
  if (mutators_.empty()) {
    mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0));
  }

  // Key every rule with a positive weight by the running sum of weights, so
  // that Mutate() can pick a rule with a single upper_bound lookup.
  double accum_weight = 0.0;
  for (const auto& mutator : mutators_) {
    if (std::get<1>(mutator) > 0) {
      accum_weight += std::get<1>(mutator);
      weighted_mutators_.insert(std::make_pair(
          accum_weight, MutateRule::Make(std::get<0>(mutator))));
    }
  }

  // Use make_unique instead of a raw `new` handed to emplace_back: the raw
  // pointer would leak if growing the vector threw before it was adopted.
  post_schedule_rules_.push_back(std::make_unique<CooperativeProcess>());
}
// Nothing to release by hand; all members clean up after themselves.
EvolutionarySearch::~EvolutionarySearch() = default;
/**
 * Run one search iteration and return the first state of
 * SearchModuleExprBests().
 *
 * Aborts with a CHECK failure when the search yields no state, instead of
 * indexing an empty vector (which is undefined behaviour).
 */
SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) {
  std::vector<SearchState> bests = SearchModuleExprBests(options);
  CHECK(!bests.empty())
      << "SearchModuleExprBests returned no state, cannot pick the best one";
  return bests.front();
}
std
::
vector
<
SearchState
>
EvolutionarySearch
::
SearchModuleExprBests
(
const
TuningOptions
&
options
)
{
VLOG
(
4
)
<<
"start SearchModuleExprBests with initial statistics: "
"visited_candidates size="
<<
visited_candidates_
.
size
();
std
::
vector
<
SearchState
>
init_population
;
std
::
vector
<
SearchState
>
topk_from_database
=
GetTopKCandidatesFromDatabase
(
options
.
evolution_pick_database_topk
);
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::GetTopKCandidatesFromDatabase"
,
topk_from_database
,
/*verbose=*/
VLOG_IS_ON
(
5
));
int
init_num
=
options
.
evolution_init_population_num
-
topk_from_database
.
size
();
std
::
vector
<
SearchState
>
init_sketch
=
InitSketch
(
init_num
,
"rule_prune"
);
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::InitSketch"
,
init_sketch
,
/*verbose=*/
VLOG_IS_ON
(
5
));
init_population
.
insert
(
init_population
.
end
(),
topk_from_database
.
begin
(),
topk_from_database
.
end
());
init_population
.
insert
(
init_population
.
end
(),
init_sketch
.
begin
(),
init_sketch
.
end
());
std
::
vector
<
SearchState
>
picked_bests
=
Evolve
(
init_population
,
options
.
evolution_cross_over_num
,
options
.
num_samples_per_iteration
);
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Evolve"
,
picked_bests
,
/*verbose=*/
VLOG_IS_ON
(
5
));
return
picked_bests
;
}
/**
 * Run one search iteration, then mix the picked states with sketches
 * generated by the "random_prune" strategy through
 * PickNextGenerationEpsGreedy().
 */
std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(
    const TuningOptions& options) {
  std::vector<SearchState> evolved_bests = SearchModuleExprBests(options);

  // The random sketches are generated after the evolutionary pass.
  int random_sketch_num = options.evolution_init_population_num -
                          options.evolution_pick_database_topk;
  std::vector<SearchState> random_sketches =
      InitSketch(random_sketch_num, "random_prune");

  auto next_generation =
      PickNextGenerationEpsGreedy(evolved_bests,
                                  random_sketches,
                                  options.num_samples_per_iteration,
                                  options.evolution_eps_greedy);
  VLOG(4) << JoinStatesDebugString(
      "EvolutionarySearch::PickNextGenerationEpsGreedy",
      next_generation,
      /*verbose=*/VLOG_IS_ON(5));
  return next_generation;
}
std
::
vector
<
SearchState
>
EvolutionarySearch
::
GetTopKCandidatesFromDatabase
(
int
topk
)
{
std
::
vector
<
SearchState
>
results
;
const
auto
&
task_key
=
tune_task_
.
serialized_key
;
auto
records
=
database_
->
GetTopK
(
task_key
,
topk
);
InitialTaskRegistry
*
task_registry
=
InitialTaskRegistry
::
Global
();
for
(
auto
&&
record
:
records
)
{
ir
::
IRSchedule
ir_sch
(
optim
::
IRCopy
(
task_registry
->
Get
(
task_key
)
->
module_expr
),
utils
::
ForkRandomState
(
&
rand_seed_
));
ir
::
ScheduleDesc
::
ReplayWithProto
(
record
.
trace
,
&
ir_sch
);
results
.
emplace_back
(
SearchState
(
std
::
move
(
ir_sch
),
record
.
predicted_cost
));
}
return
results
;
}
/**
 * Tag `schedule` as being in the post-schedule phase, then apply every rule
 * in `post_schedule_rules` to it, in order.
 */
void ApplyPostScheduleRules(
    ir::IRSchedule* schedule,
    const std::vector<std::unique_ptr<PostScheduleRule>>&
        post_schedule_rules) {
  schedule->TagPostSchedule();
  for (auto rule_it = post_schedule_rules.begin();
       rule_it != post_schedule_rules.end();
       ++rule_it) {
    (*rule_it)->Apply(schedule);
  }
}
std
::
vector
<
SearchState
>
EvolutionarySearch
::
InitSketch
(
int
num
,
const
std
::
string
&
strategy
)
{
VLOG
(
4
)
<<
"InitSketch with num:"
<<
num
<<
", strategy: "
<<
strategy
;
std
::
vector
<
SearchState
>
states
=
search_space_
->
GenerateSketches
(
num
,
strategy
);
auto
post_schedule_fn
=
[
this
,
&
states
](
int
index
)
{
ApplyPostScheduleRules
(
&
states
[
index
]
->
ir_schedule
,
post_schedule_rules_
);
};
utils
::
parallel_run
(
post_schedule_fn
,
utils
::
SequenceDispatcher
(
0
,
states
.
size
()),
states
.
size
());
return
states
;
}
/**
 * Produce a child state whose i-th expression is a copy of the i-th
 * expression of either `state1` or `state2`, chosen at random.
 *
 * Both states must hold the same number of expressions. When the cost model
 * flag is on, the child's predicted cost is filled in.
 */
SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
                                          const SearchState& state2) {
  // TODO(CtfGo): tracing CrossOver with IRSchedule
  std::vector<ir::Expr> child_exprs;
  std::vector<ir::Expr> father_exprs =
      state1->ir_schedule.GetModule().GetExprs();
  std::vector<ir::Expr> mother_exprs =
      state2->ir_schedule.GetModule().GetExprs();

  CHECK_EQ(father_exprs.size(), mother_exprs.size())
      << "CrossOver ModuleExpr in EvolutionarySearch must have same number of "
         "AST";

  for (size_t idx = 0; idx < father_exprs.size(); ++idx) {
    // A sampled 0 takes the father's expression, anything else the mother's.
    const bool take_father =
        utils::SampleUniformInt(0, 2, &rand_seed_) == 0;
    const ir::Expr& chosen =
        take_father ? father_exprs[idx] : mother_exprs[idx];
    child_exprs.push_back(optim::IRCopy(chosen));
  }

  auto child = SearchState(ir::IRSchedule(
      ir::ModuleExpr(child_exprs), utils::ForkRandomState(&rand_seed_)));
  if (FLAGS_auto_schedule_use_cost_model) {
    child->predicted_cost = cost_model_.Predict(
        child->ir_schedule.GetModule(), tune_task_.target);
  }

  VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver",
                                   {state1, state2, child},
                                   /*verbose=*/VLOG_IS_ON(6));
  return child;
}
/**
 * Create a mutated variant of `state`.
 *
 * A mutate rule is sampled according to the accumulated weights, applied to
 * the schedule trace of `state`, and the mutated trace is replayed on a copy
 * of the task's registered initial ModuleExpr. Post-schedule rules are then
 * applied to the new schedule.
 *
 * @param state     the state to mutate
 * @param rand_seed random state passed to the sampling helpers and forked
 *                  for the new schedule
 * @return a new SearchState wrapping the mutated schedule
 */
SearchState EvolutionarySearch::Mutate(
    const SearchState& state,
    utils::LinearRandomEngine::StateType* rand_seed) {
  CHECK_GT(weighted_mutators_.size(), 0)
      << "There is no mutate rule can be applied.";
  // The largest key is the total weight of all rules.
  double accu_weight = (weighted_mutators_.rbegin())->first;
  CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0.";

  // sample a mutate rule
  double sample_weight =
      utils::SampleUniformDouble(0, accu_weight, rand_seed);
  auto sampled_iter = weighted_mutators_.upper_bound(sample_weight);
  if (sampled_iter == weighted_mutators_.end()) {
    // upper_bound finds no key when the sample equals the total weight; fall
    // back to the last rule instead of dereferencing end().
    --sampled_iter;
  }
  MutateRule* mutator = sampled_iter->second.get();
  CHECK(mutator) << "mutator not defined";

  // apply mutation on the trace of SearchState
  auto trace = state->ir_schedule.GetTraceDesc();
  auto new_trace = mutator->Apply(trace, rand_seed);

  // replay the mutated trace on original ModuleExpr to generate a new
  // ir_schedule
  const auto& task_key = tune_task_.serialized_key;
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  ir::IRSchedule new_ir_sch(
      optim::IRCopy(task_registry->Get(task_key)->module_expr),
      utils::ForkRandomState(rand_seed));
  new_trace.Replay(&new_ir_sch, true);
  ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);

  auto res = SearchState(std::move(new_ir_sch));
  VLOG(5) << JoinStatesDebugString("EvolutionarySearch::Mutate",
                                   {state, res},
                                   /*verbose=*/VLOG_IS_ON(6));
  return res;
}
std
::
vector
<
SearchState
>
EvolutionarySearch
::
Evolve
(
const
std
::
vector
<
SearchState
>&
population
,
int
cross_over_num
,
int
ret_num
)
{
VLOG
(
4
)
<<
utils
::
StringFormat
(
"Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu"
,
population
.
size
(),
cross_over_num
,
ret_num
);
int
generation_num
=
population
.
size
();
if
(
generation_num
==
0
)
{
return
std
::
vector
<
SearchState
>
();
}
// init evolution
std
::
vector
<
SearchState
>
evolution
(
population
);
for
(
SearchState
&
search_state
:
evolution
)
{
if
(
search_state
->
predicted_cost
==
SearchState
::
NOT_INIT_COST
&&
FLAGS_auto_schedule_use_cost_model
)
{
search_state
->
predicted_cost
=
cost_model_
.
Predict
(
search_state
->
ir_schedule
.
GetModule
(),
tune_task_
.
target
);
}
}
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Evolve: Init evolution:"
,
evolution
,
/*verbose=*/
VLOG_IS_ON
(
5
));
// cross over
for
(
int
i
=
0
;
i
<
cross_over_num
;
++
i
)
{
int
first_rand_idx
=
utils
::
SampleUniformInt
(
0
,
generation_num
,
&
rand_seed_
);
int
second_rand_idx
=
utils
::
SampleUniformInt
(
0
,
generation_num
,
&
rand_seed_
);
while
(
first_rand_idx
==
second_rand_idx
)
{
second_rand_idx
=
utils
::
SampleUniformInt
(
0
,
generation_num
,
&
rand_seed_
);
}
evolution
.
push_back
(
CrossOver
(
population
[
first_rand_idx
],
population
[
second_rand_idx
]));
}
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Evolve: after CrossOver evolution:"
,
evolution
,
/*verbose=*/
VLOG_IS_ON
(
5
));
// mutate
std
::
vector
<
SearchState
>
mutated_individuals
(
evolution
.
size
());
std
::
vector
<
utils
::
LinearRandomEngine
::
StateType
>
rand_seeds
(
evolution
.
size
());
for
(
int
i
=
0
;
i
<
rand_seeds
.
size
();
++
i
)
{
rand_seeds
[
i
]
=
utils
::
ForkRandomState
(
&
rand_seed_
);
}
auto
mutate_fn
=
[
this
,
&
evolution
,
&
mutated_individuals
,
&
rand_seeds
](
int
index
)
{
mutated_individuals
[
index
]
=
Mutate
(
evolution
[
index
],
&
rand_seeds
[
index
]);
};
utils
::
parallel_run
(
mutate_fn
,
utils
::
SequenceDispatcher
(
0
,
evolution
.
size
()),
evolution
.
size
());
if
(
FLAGS_auto_schedule_use_cost_model
)
{
for
(
size_t
i
=
0
;
i
<
mutated_individuals
.
size
();
++
i
)
{
mutated_individuals
[
i
]
->
predicted_cost
=
cost_model_
.
Predict
(
mutated_individuals
[
i
]
->
ir_schedule
.
GetModule
(),
tune_task_
.
target
);
}
}
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Evolve: mutated individuals:"
,
mutated_individuals
,
/*verbose=*/
VLOG_IS_ON
(
5
));
// select top ret_num with predicted cost
utils
::
SizedMultiSet
<
SearchState
>
evolution_with_cost
(
ret_num
);
for
(
size_t
i
=
0
;
i
<
evolution
.
size
();
++
i
)
{
evolution_with_cost
.
Push
(
evolution
[
i
]);
}
for
(
size_t
i
=
0
;
i
<
mutated_individuals
.
size
();
++
i
)
{
evolution_with_cost
.
Push
(
mutated_individuals
[
i
]);
}
auto
selected_individuals
=
evolution_with_cost
.
ReturnAsContainer
<
std
::
vector
<
SearchState
>>
();
VLOG
(
4
)
<<
JoinStatesDebugString
(
"EvolutionarySearch::Evolve: selected individuals:"
,
selected_individuals
,
/*verbose=*/
VLOG_IS_ON
(
5
));
return
selected_individuals
;
}
/**
 * Assemble the next generation from searched bests and random sketches.
 *
 * Up to `num - int(num * eps_greedy)` slots are filled from `picked_bests`
 * first; remaining slots are filled from `random_init`, and from
 * `picked_bests` again once `random_init` is exhausted. Candidates already
 * present in visited_candidates_ are skipped; accepted ones are recorded
 * there.
 */
std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(
    const std::vector<SearchState>& picked_bests,
    const std::vector<SearchState>& random_init,
    int num,
    float eps_greedy) {
  int num_rands = num * eps_greedy;
  int num_bests = num - num_rands;

  std::vector<SearchState> result;
  SearchState selected;
  int deduplicated_cnt = 0;
  int best_idx = 0;
  int rand_idx = 0;

  while (result.size() < num) {
    const bool bests_left = best_idx < picked_bests.size();
    const bool rands_left = rand_idx < random_init.size();

    // Choose where the next candidate comes from.
    if (result.size() < num_bests && bests_left) {
      selected = picked_bests[best_idx];
      ++best_idx;
    } else if (rands_left) {
      selected = random_init[rand_idx];
      ++rand_idx;
    } else if (bests_left) {
      selected = picked_bests[best_idx];
      ++best_idx;
    } else {
      // Both sources are exhausted.
      break;
    }

    if (visited_candidates_.count(selected)) {
      // Already visited: count it and move on to the next candidate.
      ++deduplicated_cnt;
      VLOG(4) << JoinStatesDebugString(
          "EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated",
          {selected},
          /*verbose=*/VLOG_IS_ON(5));
      continue;
    }

    VLOG(4) << JoinStatesDebugString(
        "EvolutionarySearch::PickNextGenerationEpsGreedy-Selected",
        {selected},
        /*verbose=*/VLOG_IS_ON(5));
    visited_candidates_.insert(selected);
    result.push_back(selected);
  }

  VLOG(4) << utils::StringFormat(
      "PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init "
      "size=%lu,num=%d,"
      "eps_greedy=%f,deduplicated_cnt=%d,result size=%lu",
      picked_bests.size(),
      random_init.size(),
      num,
      eps_greedy,
      deduplicated_cnt,
      result.size());
  return result;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Evolutionary search over the ModuleExpr search space of a single TuneTask.
 */
class EvolutionarySearch {
 public:
  /**
   * Constructor.
   *
   * @param tune_task    the TuneTask this class works on; it is stored by
   *                     reference and not owned.
   * @param cost_model   the cost model used for cost prediction; stored by
   *                     reference and not owned.
   * @param database     the record database; not owned.
   * @param rand_seed    initial random state.
   * @param mutate_rules names of mutate rules and their weights.
   */
  EvolutionarySearch(
      const TuneTask& tune_task,
      const ExprCostModel& cost_model,
      Database* database,
      utils::LinearRandomEngine::StateType rand_seed = -1,
      const std::vector<std::tuple<std::string, double>>& mutate_rules = {});

  /**
   * Destructor
   */
  ~EvolutionarySearch();

  /**
   * Run the evolutionary search for one iteration.
   *
   * @return SearchState containing the best ir::ModuleExpr searched in this
   *     iteration
   */
  SearchState SearchModuleExpr(const TuningOptions& options);

  /**
   * Run the evolutionary search for one iteration.
   *
   * @return SearchState(s) containing best ir::ModuleExpr(s) searched in this
   *     iteration
   */
  std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options);

  /**
   * Run the evolutionary search for one iteration. Because evolutionary
   * search with a cost model may not be accurate, this method picks
   * "eps * total_return_size" random samples along with the best
   * ir::ModuleExpr's searched in this iteration.
   *
   * @return SearchStates containing the best ir::ModuleExpr's searched in
   *     this iteration and some random samples. There are
   *     "eps * total_return_size" random samples and
   *     "(1 - eps) * total_return_size" best searched samples.
   */
  std::vector<SearchState> SearchModuleExprEpsGreedy(
      const TuningOptions& options);

#ifdef CINN_WITH_TEST
  /**
   * Only to be called during testing. Installs a mock search space.
   *
   * @param search_space: the mock search space; EvolutionarySearch takes
   *     ownership of it.
   */
  void SetSearchSpace(SearchSpace* search_space) {
    search_space_.reset(search_space);
  }

  // Only to be called during testing; wraps the private method InitSketch().
  std::vector<SearchState> TestInitSketch(int num,
                                          const std::string& strategy) {
    return InitSketch(num, strategy);
  }

  // Only to be called during testing; wraps the private method Evolve().
  std::vector<SearchState> TestEvolve(
      const std::vector<SearchState>& population,
      int cross_over_num,
      int ret_num) {
    return Evolve(population, cross_over_num, ret_num);
  }
#endif

 private:
  // Build SearchStates from the top-k records the database holds for the
  // task.
  std::vector<SearchState> GetTopKCandidatesFromDatabase(int topk);

  /**
   * \brief Generate sketches as the initial population of the search.
   * @param num The number of sketches to generate.
   * @param strategy The strategy used to generate sketches. Current options
   *   are "rule_prune", "random_prune" and "random".
   * - "rule_prune": uses rules to prune and generate sketches as
   *   efficiently as possible.
   * - "random_prune": uses the new interface ApplySketchRules() to simulate
   *   the random generation of sketches, and supports a rule returning
   *   multiple SearchStates and random pruning by probability.
   * - "random": randomly selects a block and a rule to apply and repeats
   *   this step several times; however, each rule can be used on one
   *   SearchState at most once.
   * @return Generated sketches.
   */
  std::vector<SearchState> InitSketch(int num, const std::string& strategy);

  // Produce a mutated variant of `state`, sampling with `rand_seed`.
  SearchState Mutate(const SearchState& state,
                     utils::LinearRandomEngine::StateType* rand_seed);

  // Combine the expressions of two states into a new state.
  SearchState CrossOver(const SearchState& state1, const SearchState& state2);

  // Cross over and mutate `population`, then select the result set.
  std::vector<SearchState> Evolve(const std::vector<SearchState>& population,
                                  int cross_over_num,
                                  int ret_num);

  // Mix searched bests with random sketches, skipping visited candidates.
  std::vector<SearchState> PickNextGenerationEpsGreedy(
      const std::vector<SearchState>& population,
      const std::vector<SearchState>& random_init,
      int num,
      float eps_greedy);

 private:
  std::unique_ptr<SearchSpace> search_space_;
  const TuneTask& tune_task_;
  const ExprCostModel& cost_model_;  // not owned
  Database* database_;               // not owned
  // used to deduplicate states with the same structural IR
  std::unordered_set<SearchState, SearchStateHash, SearchStateEqual>
      visited_candidates_;
  // mutate rule names and their weights
  std::vector<std::tuple<std::string, double>> mutators_;
  // mutate rules, the key is the accumulated weight of each mutate rule
  std::map<double, std::unique_ptr<MutateRule>> weighted_mutators_;
  // schedule rules used after mutation
  std::vector<std::unique_ptr<PostScheduleRule>> post_schedule_rules_;
  utils::LinearRandomEngine::StateType rand_seed_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
// Build op-level tune tasks for `program` on `target`, initialize each of
// them, and register its lowered function bodies in the global
// InitialTaskRegistry under the task's serialized key.
std::vector<TuneTask> CreateTasks(const frontend::Program& program,
                                  const Target& target) {
  auto graph = std::make_shared<hlir::framework::Graph>(program, target);

  TaskCreator creator;
  auto tasks = creator.CreateTuneTaskOpLevel(graph.get());

  const auto& dtype_dict =
      graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
          "inferdtype");
  const auto& shape_dict = graph->GetAttrs<
      absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
      "infershape");
  auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
      dtype_dict, shape_dict, target);

  InitialTaskRegistry* registry = InitialTaskRegistry::Global();
  for (auto& task : tasks) {
    task.Initialize(shape_dict, dtype_dict, op_lowerer.get());
    registry->Regist(task.serialized_key,
                     ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
  }
  return tasks;
}
/**
* A mock search space is only used for test. It creates integer ir::Expr from
* 0, -1, -2, ... and set the cost value same as the integer value.
*
* So evolutionary search should be able to find the minimal ModuleExpr with
* smallest ir::Expr. This file tests it.
*/
class
MockSearchSpace
:
public
SearchSpace
{
public:
explicit
MockSearchSpace
(
const
TuneTask
&
tune_task
)
:
SearchSpace
(
tune_task
)
{}
int
GetMinExprValue
()
const
{
return
min_expr_value_
;
}
int
GetModuleExprSize
()
const
{
return
module_expr_size_
;
}
std
::
vector
<
SearchState
>
GenerateSketches
(
int
num
,
const
std
::
string
&
strategy
)
override
{
std
::
vector
<
SearchState
>
ret
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
std
::
vector
<
ir
::
Expr
>
exprs
;
for
(
int
j
=
0
;
j
<
module_expr_size_
;
++
j
)
{
exprs
.
push_back
(
ir
::
Expr
(
-
i
));
}
min_expr_value_
=
-
i
;
ret
.
push_back
(
SearchState
(
ir
::
IRSchedule
(
ir
::
ModuleExpr
(
exprs
))));
}
return
ret
;
}
private:
int
module_expr_size_
=
10
;
int
min_expr_value_
=
0
;
};
// Mock cost model: the predicted cost of a ModuleExpr is the sum of the
// int32 values of its expressions.
class MockCostModel : public ExprCostModel {
  float Predict(const ir::ModuleExpr& sample,
                const common::Target& target) const override {
    float total = 0.0f;
    std::vector<ir::Expr> exprs = sample.GetExprs();
    for (const ir::Expr& expr : exprs) {
      const auto value = expr.as_int32();
      // Only non-zero values are added.
      if (value) {
        total += static_cast<float>(value);
      }
    }
    return total;
  }
};
// SearchModuleExpr() on the mock search space must return a state whose
// expressions all carry the smallest value the mock has generated.
TEST(EvolutionarySearch, GetOneBest) {
  TuneTask mock_tune_task;
  mock_tune_task.serialized_key = "mock_task";
  mock_tune_task.target = common::DefaultTarget();

  // Register the initial ModuleExpr of the mock task.
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  task_registry->Regist(mock_tune_task.serialized_key,
                        ir::ModuleExpr({ir::Expr(0)}));

  MockCostModel cost_model;
  TuningOptions options;
  Database db(2);
  EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);

  MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
  // Ownership is transferred so don't delete mock_search_space
  evolutionary_search.SetSearchSpace(mock_search_space);

  SearchState best_state = evolutionary_search.SearchModuleExpr(options);
  std::vector<ir::Expr> best_exprs =
      best_state->ir_schedule.GetModule().GetExprs();

  EXPECT_GE(best_exprs.size(), 1UL);
  for (const ir::Expr& best_expr : best_exprs) {
    EXPECT_EQ(best_expr.as_int32(), mock_search_space->GetMinExprValue());
  }
}
// SearchModuleExprEpsGreedy() on the mock search space must return at least
// one state, and every returned state must hold as many expressions as the
// mock places in a ModuleExpr.
TEST(EvolutionarySearch, GetEpsGreedy) {
  TuneTask mock_tune_task;
  mock_tune_task.serialized_key = "mock_task";
  mock_tune_task.target = common::DefaultTarget();

  // Register the initial ModuleExpr of the mock task.
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  task_registry->Regist(mock_tune_task.serialized_key,
                        ir::ModuleExpr({ir::Expr(0)}));

  ExprCostModel cost_model;
  TuningOptions options;
  Database db(2);
  EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);

  MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
  // Ownership is transferred so don't delete mock_search_space
  evolutionary_search.SetSearchSpace(mock_search_space);

  std::vector<SearchState> picked_states =
      evolutionary_search.SearchModuleExprEpsGreedy(options);
  EXPECT_GE(picked_states.size(), 1UL);

  size_t expected_expr_count =
      static_cast<size_t>(mock_search_space->GetModuleExprSize());
  for (const SearchState& picked : picked_states) {
    EXPECT_EQ(picked->ir_schedule.GetModule().GetExprs().size(),
              expected_expr_count);
  }
}
// Trains an ExprCostModel on a matmul task, builds an initial population with
// TestInitSketch, then evolves it for 10 generations and checks that the total
// predicted cost of each new generation does not exceed that of the generation
// it was evolved from.
TEST(EvolutionarySearch, Evolve) {
  auto target = common::DefaultNVGPUTarget();
  auto tasks = CreateTasks(
      tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}),
      target);
  CHECK_EQ(tasks.size(), 1);

  // Feed the cost model two samples of the lowered function with labels 10
  // and 11 before any search happens.
  ExprCostModel cost_model;
  std::vector<const ir::ModuleExpr*> cost_model_samples(1);
  std::vector<float> cost_model_labels(1);
  for (size_t i = 0; i < 2; ++i) {
    ir::ModuleExpr me({ir::Expr(tasks[0].lowered_funcs[0])});
    cost_model_samples[0] = &me;  // only used by Update() within this iteration
    cost_model_labels[0] = i + 10;
    cost_model.Update(cost_model_samples, cost_model_labels, target);
  }

  Database db(2);
  TuningOptions options;
  options.evolution_pick_database_topk = 0;

  EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db);

  int num_population = 10;
  std::vector<SearchState> init_sketch =
      evolutionary_search.TestInitSketch(num_population, "rule_prune");

  // Update the cost model with every initial sketch, labelled by its index.
  for (int i = 0; i < num_population; ++i) {
    ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule());
    cost_model_samples[0] = &me;
    cost_model_labels[0] = i;
    cost_model.Update(cost_model_samples, cost_model_labels, target);
  }

  VLOG(6) << "init sketch costs:";
  for (auto s : init_sketch) {
    VLOG(6) << "cost = " << s->predicted_cost;
  }

  // The previous generation lives in its own vector. Tracking generations
  // through two pointers that are swapped made both pointers refer to the same
  // (already overwritten) vector from the second iteration on, so the cost
  // comparison degenerated into comparing a population with itself.
  std::vector<SearchState> population_pre = init_sketch;
  for (int i = 0; i < 10; ++i) {
    std::vector<SearchState> population_next = evolutionary_search.TestEvolve(
        population_pre, /*cross_over_num*/ 0, /*ret_num*/ 10);

    VLOG(6) << "population[" << i + 1 << "] costs:";
    double total_cost_pre = 0.0, total_cost_next = 0.0;
    for (auto s : population_pre) {
      total_cost_pre += s->predicted_cost;
    }
    for (auto s : population_next) {
      total_cost_next += s->predicted_cost;
      VLOG(6) << "cost = " << s->predicted_cost;
    }
    VLOG(6) << "total_cost_next = " << total_cost_next;
    CHECK_LE(total_cost_next, total_cost_pre);

    // The freshly evolved generation becomes the input of the next round.
    population_pre.swap(population_next);
  }
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt
0 → 100644
View file @
992bec46
# Build rules for the mutate_rule directory.
# NOTE(review): core_gather_headers / gather_srcs / cinn_cc_test are project
# helpers defined elsewhere; the notes below only describe their arguments.
core_gather_headers
()
# Adds the mutate rule sources to the cinnapi_src source group.
gather_srcs
(
cinnapi_src SRCS mutate_rule.cc mutate_tile_size.cc
)
# Unit test target for MutateTileSize, depending on cinncore.
cinn_cc_test
(
test_mutate_tile_size SRCS mutate_tile_size_test.cc DEPS cinncore
)
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
namespace
cinn
{
namespace
auto_schedule
{
// Factory for mutate rules, looked up by name. Only "mutate_tile_size" is
// known; any other name is reported through LOG(FATAL).
std::unique_ptr<MutateRule> MutateRule::Make(const std::string& name) {
  if (name != "mutate_tile_size") {
    LOG(FATAL) << "MutateRule " << name << " is not supported.";
    // Fallback return value for the unsupported-name path.
    return nullptr;
  }
  return std::make_unique<MutateTileSize>();
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/utils/random_engine.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Base class for rules of mutate,
 * is used for mutating the trace(ScheduleDesc) to explore the search space.
 */
class MutateRule {
 public:
  MutateRule() = default;

  /**
   * Virtual destructor: rules are handed out as std::unique_ptr<MutateRule>
   * by Make(), so a derived rule is destroyed through a base-class pointer.
   */
  virtual ~MutateRule() = default;

  /**
   * @brief Apply the mutate rule to the given trace.
   * @param trace The given trace for mutation.
   * @param rand_seed The random seed for mutation.
   * @return The mutated trace.
   */
  virtual ir::ScheduleDesc Apply(
      const ir::ScheduleDesc& trace,
      utils::LinearRandomEngine::StateType* rand_seed) = 0;

  /**
   * @brief Create a MutateRule with name.
   * @param name The name of mutate rule, consisting of lowercase letters and
   * underscores
   * @return The created MutateRule.
   */
  static std::unique_ptr<MutateRule> Make(const std::string& name);
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
ir
::
ScheduleDesc
;
using
::
cinn
::
utils
::
LinearRandomEngine
;
// One SamplePerfectTile step found in a trace, stored as
// (the step, its "decision" tile factors, the index of the step in the trace).
using
SampledTile
=
std
::
tuple
<
ScheduleDesc
::
Step
,
std
::
vector
<
int
>
,
int
>
;
/**
 * @brief List all positive divisors of n in ascending order.
 * @param n The number to factorize.
 * @return The divisors of n, from 1 up to n itself; empty when n <= 0.
 */
static std::vector<int> Factorize(int n) {
  std::vector<int> divisors;        // divisors d with d * d <= n, ascending
  std::vector<int> large_divisors;  // the matching n / d, descending
  // `i <= n / i` is the overflow-free form of `i * i <= n`: the product would
  // exceed the int range for n close to INT_MAX.
  for (int i = 1; i <= n / i; ++i) {
    if (n % i != 0) {
      continue;
    }
    divisors.push_back(i);
    const int paired = n / i;
    if (paired != i) {  // do not list the square root twice
      large_divisors.push_back(paired);
    }
  }
  // Appending the large half in reverse keeps the result sorted.
  divisors.insert(divisors.end(), large_divisors.rbegin(),
                  large_divisors.rend());
  return divisors;
}
std
::
vector
<
SampledTile
>
FindSampledTiles
(
const
ScheduleDesc
&
trace
)
{
std
::
vector
<
SampledTile
>
tiles
;
int
step_idx
=
0
;
for
(
auto
&&
step
:
trace
.
Steps
())
{
if
(
step
.
type
==
"TagPostSchedule"
)
{
break
;
}
if
(
step
.
type
==
"SamplePerfectTile"
)
{
std
::
vector
<
int
>
tile_factors
=
absl
::
get
<
std
::
vector
<
int
>>
(
step
.
attrs
.
at
(
"decision"
));
CHECK
(
tile_factors
.
size
()
>=
2
)
<<
"factors size must be greater equal than 2, which is "
<<
tile_factors
.
size
();
tiles
.
push_back
(
std
::
make_tuple
(
step
,
tile_factors
,
step_idx
));
}
++
step_idx
;
}
return
tiles
;
}
/**
 * @brief Mutate the decision of one SamplePerfectTile step.
 *
 * A divisor of one tile factor (loop_x) is moved to another tile factor
 * (loop_y): factors[loop_x] /= divisor, factors[loop_y] *= divisor, so the
 * product of all factors is unchanged.
 *
 * @param trace The original trace.
 * @param tile The sampled tile to mutate: (step, decision factors, step index).
 * @param rand_seed The random state used for all sampling in this function.
 * @return A fork of the trace with the new decision, or the original trace
 *         when no mutation is possible.
 */
ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
                              const SampledTile& tile,
                              LinearRandomEngine::StateType* rand_seed) {
  const ScheduleDesc::Step& step = std::get<0>(tile);
  std::vector<int> tile_factors = std::get<1>(tile);
  const int split_size = tile_factors.size();

  // The sampling loop below only accepts a donor factor greater than 1. Use
  // the same condition here so the loop cannot spin forever on a tile that has
  // no such factor.
  bool has_divisible_factor = false;
  for (int factor : tile_factors) {
    if (factor > 1) {
      has_divisible_factor = true;
      break;
    }
  }
  if (!has_divisible_factor) {
    VLOG(6) << "Factors are all 1, unable to mutate, return the original trace";
    return trace;
  }

  while (true) {
    VLOG(6) << "while (true) loop in DoMutateTileSize";
    // Step 1. Choose 2 loops with index: 'loop_x' and 'loop_y'
    // NOTE(review): SampleUniformInt is presumably half-open [min, max), since
    // its result is used directly as an index — confirm.
    const int loop_x = utils::SampleUniformInt(0, split_size, rand_seed);
    if (tile_factors.at(loop_x) <= 1) {
      continue;
    }
    // Sample among the remaining indices, then skip over loop_x so that
    // loop_y != loop_x.
    int loop_y = utils::SampleUniformInt(0, split_size - 1, rand_seed);
    if (loop_y >= loop_x) {
      ++loop_y;
    }
    const std::vector<int> optional_factors =
        Factorize(tile_factors.at(loop_x));

    // Step 2. Choose the divisor for mutate.
    int divisor;
    if (loop_y == split_size - 1) {
      // The last factor must stay within "max_innermost_factor" after being
      // multiplied, so find the largest divisor index that still satisfies it.
      const int max_innermost_factor =
          absl::get<int>(step.attrs.at("max_innermost_factor"));
      int max_optional_factor_idx =
          static_cast<int>(optional_factors.size()) - 1;
      while (max_optional_factor_idx > 0 &&
             optional_factors.at(max_optional_factor_idx) *
                     tile_factors.at(loop_y) >
                 max_innermost_factor) {
        --max_optional_factor_idx;
      }
      if (max_optional_factor_idx == 0) {
        // Only the divisor 1 would fit, which is not a mutation.
        if (split_size <= 2) {
          VLOG(6) << "Unable to mutate, return the original trace";
          return trace;
        }
        continue;  // retry with another pair of loops
      }
      // Index 0 (the divisor 1) is never picked.
      divisor = optional_factors.at(utils::SampleUniformInt(
          1, max_optional_factor_idx + 1, rand_seed));
    } else {
      divisor = optional_factors.at(
          utils::SampleUniformInt(1, optional_factors.size(), rand_seed));
    }

    // Step 3. Determine the new tile value
    VLOG(6) << "DoMutateTileSize: divisor = " << divisor << ", before mutate:\n"
            << "factors[" << loop_x << "] = " << tile_factors[loop_x]
            << ", factors[" << loop_y << "] = " << tile_factors[loop_y];
    tile_factors[loop_x] /= divisor;
    tile_factors[loop_y] *= divisor;
    VLOG(6) << "after mutate:\n"
            << "factors[" << loop_x << "] = " << tile_factors[loop_x]
            << ", factors[" << loop_y << "] = " << tile_factors[loop_y];

    // Step 4. Create a new step with new tile values and return the new trace
    const int step_idx = std::get<2>(tile);
    return trace.ForkAndUpdate(step_idx, tile_factors, true);
  }
}
/**
 * @brief Mutate the tile factors of one randomly chosen SamplePerfectTile step.
 * @param trace The trace to mutate.
 * @param rand_seed The random state used to pick the step and the mutation.
 * @return The mutated trace, or the original trace when it contains no
 *         SamplePerfectTile step to mutate.
 */
ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace,
                                   LinearRandomEngine::StateType* rand_seed) {
  VLOG(6) << "Start applying MutateTileSize, old trace:\n"
          << trace.DebugString();

  const std::vector<SampledTile> sampled_tiles = FindSampledTiles(trace);
  if (sampled_tiles.empty()) {
    VLOG(6) << "MutateTileSize failed, try other mutate rules.";
    return trace;
  }

  // Pick which of the sampled tiles gets mutated.
  const int sample_step_idx =
      utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed);
  ScheduleDesc new_trace =
      DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed);
  VLOG(6) << "End applying MutateTileSize, new trace:\n"
          << new_trace.DebugString();
  return new_trace;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Mutate rule for tile sizes: it modifies the factors of the Split primitive.
 */
class MutateTileSize : public MutateRule {
 public:
  MutateTileSize() = default;

  /**
   * @brief Apply the tile-size mutation to the given trace.
   * @param trace The given trace for mutation.
   * @param rand_seed The random seed for mutation.
   * @return The mutated trace.
   */
  ir::ScheduleDesc Apply(
      const ir::ScheduleDesc& trace,
      utils::LinearRandomEngine::StateType* rand_seed) override;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Builds a 32x32x32 matmul, tiles its outermost loop with SamplePerfectTile +
// Split, applies MutateTileSize once and compares the replayed IR with a
// golden string. It then mutates ten more times and checks that each mutation
// changes both factors while their product stays kSize.
TEST(MutateTileSize, Basic) {
  srand(0);
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  const int kSize = 32;
  Expr M(kSize);
  Expr N(kSize);
  Expr K(kSize);

  Placeholder<float> A("A", {M, K});
  Placeholder<float> B("B", {K, N});

  Var k(K.as_int32(), "reduce_axis_k");
  ir::Tensor C = Compute(
      {M, N},
      [&](Var row, Var col) { return ReduceSum(A(row, k) * B(k, col), {k}); },
      "C");

  poly::StageMap stages = CreateStages({A, B, C});
  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
      "TestMutateTileSize_Basic", stages, {A, B, C}, {}, {}, nullptr, target,
      true);

  ir::Expr ast_expr = funcs[0]->body;
  VLOG(6) << "Original Expr: ";
  VLOG(6) << ast_expr;
  ir::ModuleExpr module_expr({ast_expr});
  // The seed is a fixed constant so that the result is reproducible.
  utils::LinearRandomEngine::StateType rand_seed = 123;
  ir::IRSchedule ir_schedule(module_expr, rand_seed);
  ir::IRSchedule new_ir_schedule(ir_schedule);

  // Apply the schedule: sample a 2-way tiling of the first loop and split it.
  auto loops = ir_schedule.GetLoops("C");
  auto factors = ir_schedule.SamplePerfectTile(loops[0], 2, kSize);
  auto split_loops = ir_schedule.Split(loops[0], factors);

  // Apply the mutation and replay the mutated trace on the untouched copy.
  MutateTileSize mutator;
  ir::ScheduleDesc sch_desc =
      mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
  sch_desc.Replay(&new_ir_schedule, true);
  VLOG(6) << "Expr before mutate tile size:\n"
          << ir_schedule.GetModule().GetExprs()[0];
  VLOG(6) << "Expr after mutate tile size:\n"
          << new_ir_schedule.GetModule().GetExprs()[0];

  std::string target_new_ir = R"ROC({
ScheduleBlock(root)
{
serial for (i_1, 0, 2)
{
serial for (i_2, 0, 16)
{
serial for (j, 0, 32)
{
ScheduleBlock(C__reduce_init)
{
i0, i1 = axis.bind(((16 * i_1) + i_2), j)
C__reduce_init[i0, i1] = 0.00000000f
}
serial for (reduce_axis_k, 0, 32)
{
ScheduleBlock(C)
{
i0_0, i1_0, i2 = axis.bind(((16 * i_1) + i_2), j, reduce_axis_k)
C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0]))
}
}
}
}
}
}
})ROC";

  // Print the single expression of the mutated schedule and compare it with
  // the golden IR.
  std::vector<ir::Expr> mutated_exprs = new_ir_schedule.GetModule().GetExprs();
  EXPECT_EQ(mutated_exprs.size(), 1UL);
  std::stringstream mutated_ir;
  mutated_ir << mutated_exprs[0];
  ASSERT_EQ(mutated_ir.str(), target_new_ir);

  std::vector<int> last_tile_factors = {2, 16};
  for (int round = 0; round < 10; ++round) {
    sch_desc = mutator.Apply(sch_desc, &rand_seed);
    for (auto&& step : sch_desc.Steps()) {
      if (step.type != "SamplePerfectTile") {
        continue;
      }
      std::vector<int> tile_factors =
          absl::get<std::vector<int>>(step.attrs.at("decision"));
      ASSERT_EQ(tile_factors.size(), last_tile_factors.size());
      ASSERT_NE(tile_factors[0], last_tile_factors[0]);
      ASSERT_NE(tile_factors[1], last_tile_factors[1]);
      ASSERT_EQ(tile_factors[0] * tile_factors[1], kSize);
      last_tile_factors = tile_factors;
    }
  }
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/task/CMakeLists.txt
0 → 100644
View file @
992bec46
# Build rules for the auto_schedule task directory.
core_gather_headers()

# Task sources collected into cinnapi_src. Each file is listed exactly once; a
# second gather_srcs call repeating task_creator.cc and task_optimizer.cc was
# redundant with this one.
gather_srcs(cinnapi_src SRCS task_creator.cc task_optimizer.cc tune_task.cc)

# Unit tests of this directory, each depending on cinncore.
cinn_cc_test(test_task_creator SRCS task_creator_test.cc DEPS cinncore)
cinn_cc_test(test_tune_task SRCS tune_task_test.cc DEPS cinncore)
cinn_cc_test(test_task_registry SRCS task_registry_test.cc DEPS cinncore)
paddle/cinn/auto_schedule/task/task_creator.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include <glog/logging.h>
#include <memory>
#include <tuple>
#include <vector>
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/pass.h"
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
common
::
GraphEdge
;
using
::
cinn
::
common
::
GraphNode
;
using
::
cinn
::
hlir
::
framework
::
Graph
;
using
::
cinn
::
hlir
::
framework
::
Node
;
using
::
cinn
::
hlir
::
framework
::
NodeData
;
/**
 * @brief Create one TuneTask per fusion group of the graph.
 *
 * When the graph has no fusion groups yet, "BuildNonFusedGroupsPass" is applied
 * first and the groups are read afterwards.
 *
 * @param graph The graph to create tasks from; it may be modified by the pass.
 * @return One task per group, with its subgraph set to the group and its
 *         target set to the graph's target.
 */
std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
  // The input graph doesn't run Op Fusion
  if (graph->fusion_groups.empty()) {
    // NOTE(review): the pass is presumably what fills graph->fusion_groups.
    hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
  }
  const std::vector<std::shared_ptr<Graph::Group>>& groups =
      graph->fusion_groups;
  VLOG(3) << "Graph groups size:" << groups.size();

  std::vector<TuneTask> ret_tasks;
  ret_tasks.reserve(groups.size());
  for (const auto& sub_graph : groups) {
    ret_tasks.emplace_back();
    ret_tasks.back().subgraph = sub_graph;
    ret_tasks.back().target = graph->target_;
  }
  return ret_tasks;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/task/task_creator.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/graph.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Creates the tasks used by auto tuning.
 */
class TaskCreator {
 public:
  /**
   * @brief Create op-level tune tasks from a graph.
   * @param graph The graph to create tasks from.
   * @return The created tune tasks.
   */
  std::vector<TuneTask> CreateTuneTaskOpLevel(hlir::framework::Graph* graph);
};
}
// namespace auto_schedule
}
// namespace cinn
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment