Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d1c8d840
"...composable_kernel.git" did not exist on "c9013009a0f093df44a318e265336ebaf3c68a2f"
Unverified
Commit
d1c8d840
authored
Jun 07, 2021
by
Yuge Zhang
Committed by
GitHub
Jun 07, 2021
Browse files
Fix a few issues in Retiarii (#3725)
parent
6b52fb12
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
11 deletions
+27
-11
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+3
-3
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+6
-5
nni/retiarii/strategy/_rl_impl.py
nni/retiarii/strategy/_rl_impl.py
+4
-1
nni/retiarii/strategy/bruteforce.py
nni/retiarii/strategy/bruteforce.py
+2
-2
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+12
-0
No files found.
nni/retiarii/nn/pytorch/component.py
View file @
d1c8d840
...
...
@@ -72,7 +72,7 @@ class Repeat(nn.Module):
class
Cell
(
nn
.
Module
):
"""
Cell structure [
1]_ [2
]_ that is popularly used in NAS literature.
Cell structure [
zophnas]_ [zophnasnet
]_ that is popularly used in NAS literature.
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. Predecessor means the input of cell.
...
...
@@ -95,8 +95,8 @@ class Cell(nn.Module):
References
----------
.. [
1
] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [
2
] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
.. [
zophnas
] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [
zophnasnet
] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
d1c8d840
...
...
@@ -15,7 +15,7 @@ from ...utils import uid
class
LayerChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
super
().
__init__
()
super
().
__init__
(
label
=
nodes
[
0
].
operation
.
parameters
[
'label'
]
)
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
...
...
@@ -40,7 +40,7 @@ class LayerChoiceMutator(Mutator):
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
super
().
__init__
()
super
().
__init__
(
label
=
nodes
[
0
].
operation
.
parameters
[
'label'
]
)
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
...
...
@@ -56,7 +56,7 @@ class InputChoiceMutator(Mutator):
class
ValueChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
],
candidates
:
List
[
Any
]):
super
().
__init__
()
super
().
__init__
(
label
=
nodes
[
0
].
operation
.
parameters
[
'label'
]
)
self
.
nodes
=
nodes
self
.
candidates
=
candidates
...
...
@@ -69,7 +69,8 @@ class ValueChoiceMutator(Mutator):
class
ParameterChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Tuple
[
Node
,
str
]],
candidates
:
List
[
Any
]):
super
().
__init__
()
node
,
argname
=
nodes
[
0
]
super
().
__init__
(
label
=
node
.
operation
.
parameters
[
argname
].
label
)
self
.
nodes
=
nodes
self
.
candidates
=
candidates
...
...
@@ -84,7 +85,7 @@ class ParameterChoiceMutator(Mutator):
class
RepeatMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
# nodes is a subgraph consisting of repeated blocks.
super
().
__init__
()
super
().
__init__
(
label
=
nodes
[
0
].
operation
.
parameters
[
'label'
]
)
self
.
nodes
=
nodes
def
_retrieve_chain_from_graph
(
self
,
graph
:
Graph
)
->
List
[
Node
]:
...
...
nni/retiarii/strategy/_rl_impl.py
View file @
d1c8d840
# This file might cause import error for those who didn't install RL-related dependencies
import
logging
import
threading
from
multiprocessing.pool
import
ThreadPool
import
gym
...
...
@@ -18,6 +19,7 @@ from ..execution import submit_models, wait_models
_logger
=
logging
.
getLogger
(
__name__
)
_thread_lock
=
threading
.
Lock
()
class
MultiThreadEnvWorker
(
EnvWorker
):
...
...
@@ -100,7 +102,8 @@ class ModelEvaluationEnv(gym.Env):
if
self
.
cur_step
<
self
.
num_steps
else
self
.
action_dim
}
if
self
.
cur_step
==
self
.
num_steps
:
model
=
get_targeted_model
(
self
.
base_model
,
self
.
mutators
,
self
.
sample
)
with
_thread_lock
:
model
=
get_targeted_model
(
self
.
base_model
,
self
.
mutators
,
self
.
sample
)
_logger
.
info
(
f
'New model created:
{
self
.
sample
}
'
)
submit_models
(
model
)
wait_models
(
model
)
...
...
nni/retiarii/strategy/bruteforce.py
View file @
d1c8d840
...
...
@@ -62,7 +62,7 @@ class GridSearch(BaseStrategy):
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
grid_generator
(
search_space
,
shuffle
=
self
.
shuffle
):
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
while
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
...
...
@@ -113,6 +113,6 @@ class Random(BaseStrategy):
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
random_generator
(
search_space
,
dedup
=
self
.
dedup
):
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
while
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
nni/retiarii/strategy/tpe_strategy.py
View file @
d1c8d840
...
...
@@ -40,6 +40,18 @@ class TPESampler(Sampler):
class
TPEStrategy
(
BaseStrategy
):
"""
The Tree-structured Parzen Estimator (TPE) [bergstrahpo]_ is a sequential model-based optimization (SMBO) approach.
SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
and then subsequently choose new hyperparameters to test based on this model.
References
----------
.. [bergstrahpo] Bergstra et al., "Algorithms for Hyper-Parameter Optimization".
https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf
"""
def
__init__
(
self
):
self
.
tpe_sampler
=
TPESampler
()
self
.
model_id
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment