Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
aa98462d
"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "47039f06979f2455e5e73f8807791d4e6a1c027f"
Unverified
Commit
aa98462d
authored
Oct 12, 2022
by
Yuge Zhang
Committed by
GitHub
Oct 12, 2022
Browse files
Dedup for evolution search (#5092)
parent
bcc640c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
6 deletions
+36
-6
nni/nas/strategy/evolution.py
nni/nas/strategy/evolution.py
+29
-6
test/algo/nas/test_strategy.py
test/algo/nas/test_strategy.py
+7
-0
No files found.
nni/nas/strategy/evolution.py
View file @
aa98462d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
collections
import
collections
import
dataclasses
import
dataclasses
import
logging
import
logging
import
random
import
random
import
time
import
time
from
typing
import
Deque
from
nni.nas.execution
import
query_available_resources
,
submit_models
from
nni.nas.execution
import
query_available_resources
,
submit_models
from
nni.nas.execution.common
import
ModelStatus
from
nni.nas.execution.common
import
Model
,
ModelStatus
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
from
.utils
import
dry_run_for_search_space
,
get_targeted_model
,
filter_model
from
.utils
import
dry_run_for_search_space
,
get_targeted_model
,
filter_model
...
@@ -43,6 +46,10 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -43,6 +46,10 @@ class RegularizedEvolution(BaseStrategy):
The number of individuals that should participate in each tournament. Default: 25.
The number of individuals that should participate in each tournament. Default: 25.
mutation_prob : float
mutation_prob : float
Probability that mutation happens in each dim. Default: 0.05
Probability that mutation happens in each dim. Default: 0.05
dedup : bool
Do not try the same configuration twice. Default: true.
dedup_retries : int
If dedup is true, retry the same configuration up to dedup_retries times. Default: 500.
on_failure : str
on_failure : str
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
...
@@ -52,7 +59,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -52,7 +59,7 @@ class RegularizedEvolution(BaseStrategy):
"""
"""
def
__init__
(
self
,
optimize_mode
=
'maximize'
,
population_size
=
100
,
sample_size
=
25
,
cycles
=
20000
,
def
__init__
(
self
,
optimize_mode
=
'maximize'
,
population_size
=
100
,
sample_size
=
25
,
cycles
=
20000
,
mutation_prob
=
0.05
,
on_failure
=
'ignore'
,
model_filter
=
None
):
mutation_prob
=
0.05
,
dedup
=
False
,
dedup_retries
=
500
,
on_failure
=
'ignore'
,
model_filter
=
None
):
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
assert
on_failure
in
[
'ignore'
,
'worst'
]
assert
on_failure
in
[
'ignore'
,
'worst'
]
assert
sample_size
<
population_size
assert
sample_size
<
population_size
...
@@ -61,13 +68,16 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -61,13 +68,16 @@ class RegularizedEvolution(BaseStrategy):
self
.
sample_size
=
sample_size
self
.
sample_size
=
sample_size
self
.
cycles
=
cycles
self
.
cycles
=
cycles
self
.
mutation_prob
=
mutation_prob
self
.
mutation_prob
=
mutation_prob
self
.
dedup
=
dedup
self
.
dedup_retries
=
dedup_retries
self
.
on_failure
=
on_failure
self
.
on_failure
=
on_failure
self
.
_worst
=
float
(
'-inf'
)
if
self
.
optimize_mode
==
'maximize'
else
float
(
'inf'
)
self
.
_worst
=
float
(
'-inf'
)
if
self
.
optimize_mode
==
'maximize'
else
float
(
'inf'
)
self
.
_success_count
=
0
self
.
_success_count
=
0
self
.
_population
=
collections
.
deque
()
self
.
_history_configs
:
list
[
str
]
=
[]
# for dedup. has to be a list because keys are non-hashable.
self
.
_running_models
=
[]
self
.
_population
:
Deque
[
Individual
]
=
collections
.
deque
()
self
.
_running_models
:
list
[
tuple
[
dict
,
Model
]]
=
[]
self
.
_polling_interval
=
2.
self
.
_polling_interval
=
2.
self
.
filter
=
model_filter
self
.
filter
=
model_filter
...
@@ -95,6 +105,18 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -95,6 +105,18 @@ class RegularizedEvolution(BaseStrategy):
parent
=
min
(
samples
,
key
=
lambda
sample
:
sample
.
y
)
parent
=
min
(
samples
,
key
=
lambda
sample
:
sample
.
y
)
return
parent
.
x
return
parent
.
x
def
repeat_until_new_config
(
self
,
generator
):
if
not
self
.
dedup
:
# Do nothing if not deduplicating
return
generator
()
for
_
in
range
(
self
.
dedup_retries
):
config
=
generator
()
if
config
not
in
self
.
_history_configs
:
return
config
_logger
.
warning
(
'Deduplication failed. Generating an arbitrary config.'
)
return
generator
()
def
run
(
self
,
base_model
,
applied_mutators
):
def
run
(
self
,
base_model
,
applied_mutators
):
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
# Run the first population regardless concurrency
# Run the first population regardless concurrency
...
@@ -102,7 +124,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -102,7 +124,7 @@ class RegularizedEvolution(BaseStrategy):
while
len
(
self
.
_population
)
+
len
(
self
.
_running_models
)
<=
self
.
population_size
:
while
len
(
self
.
_population
)
+
len
(
self
.
_running_models
)
<=
self
.
population_size
:
# try to submit new models
# try to submit new models
while
len
(
self
.
_population
)
+
len
(
self
.
_running_models
)
<
self
.
population_size
:
while
len
(
self
.
_population
)
+
len
(
self
.
_running_models
)
<
self
.
population_size
:
config
=
self
.
random
(
search_space
)
config
=
self
.
repeat_until_new_config
(
lambda
:
self
.
random
(
search_space
)
)
self
.
_submit_config
(
config
,
base_model
,
applied_mutators
)
self
.
_submit_config
(
config
,
base_model
,
applied_mutators
)
# collect results
# collect results
self
.
_move_succeeded_models_to_population
()
self
.
_move_succeeded_models_to_population
()
...
@@ -117,7 +139,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -117,7 +139,7 @@ class RegularizedEvolution(BaseStrategy):
while
self
.
_success_count
+
len
(
self
.
_running_models
)
<=
self
.
cycles
:
while
self
.
_success_count
+
len
(
self
.
_running_models
)
<=
self
.
cycles
:
# try to submit new models
# try to submit new models
while
query_available_resources
()
>
0
and
self
.
_success_count
+
len
(
self
.
_running_models
)
<
self
.
cycles
:
while
query_available_resources
()
>
0
and
self
.
_success_count
+
len
(
self
.
_running_models
)
<
self
.
cycles
:
config
=
self
.
mutate
(
self
.
best_parent
(),
search_space
)
config
=
self
.
repeat_until_new_config
(
lambda
:
self
.
mutate
(
self
.
best_parent
(),
search_space
)
)
self
.
_submit_config
(
config
,
base_model
,
applied_mutators
)
self
.
_submit_config
(
config
,
base_model
,
applied_mutators
)
# collect results
# collect results
self
.
_move_succeeded_models_to_population
()
self
.
_move_succeeded_models_to_population
()
...
@@ -129,6 +151,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -129,6 +151,7 @@ class RegularizedEvolution(BaseStrategy):
def
_submit_config
(
self
,
config
,
base_model
,
mutators
):
def
_submit_config
(
self
,
config
,
base_model
,
mutators
):
_logger
.
debug
(
'Model submitted to running queue: %s'
,
config
)
_logger
.
debug
(
'Model submitted to running queue: %s'
,
config
)
self
.
_history_configs
.
append
(
config
)
model
=
get_targeted_model
(
base_model
,
mutators
,
config
)
model
=
get_targeted_model
(
base_model
,
mutators
,
config
)
if
not
filter_model
(
self
.
filter
,
model
):
if
not
filter_model
(
self
.
filter
,
model
):
if
self
.
on_failure
==
"worst"
:
if
self
.
on_failure
==
"worst"
:
...
...
test/algo/nas/test_strategy.py
View file @
aa98462d
...
@@ -136,6 +136,13 @@ def test_evolution():
...
@@ -136,6 +136,13 @@ def test_evolution():
wait_models
(
*
engine
.
models
)
wait_models
(
*
engine
.
models
)
_reset_execution_engine
()
_reset_execution_engine
()
evolution
=
strategy
.
RegularizedEvolution
(
population_size
=
5
,
sample_size
=
3
,
cycles
=
10
,
mutation_prob
=
0.5
,
dedup
=
True
,
on_failure
=
'ignore'
)
engine
=
MockExecutionEngine
(
failure_prob
=
0.2
)
_reset_execution_engine
(
engine
)
evolution
.
run
(
*
_get_model_and_mutators
())
wait_models
(
*
engine
.
models
)
_reset_execution_engine
()
evolution
=
strategy
.
RegularizedEvolution
(
population_size
=
5
,
sample_size
=
3
,
cycles
=
10
,
mutation_prob
=
0.5
,
on_failure
=
'worst'
)
evolution
=
strategy
.
RegularizedEvolution
(
population_size
=
5
,
sample_size
=
3
,
cycles
=
10
,
mutation_prob
=
0.5
,
on_failure
=
'worst'
)
engine
=
MockExecutionEngine
(
failure_prob
=
0.4
)
engine
=
MockExecutionEngine
(
failure_prob
=
0.4
)
_reset_execution_engine
(
engine
)
_reset_execution_engine
(
engine
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment