Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
08af7771
"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "cc9efa4e3b76a12829d51260b60dde089becd1c7"
Unverified
Commit
08af7771
authored
Dec 28, 2020
by
QuanluZhang
Committed by
GitHub
Dec 28, 2020
Browse files
[2.0a2] [retiarii] improvement (#3208)
parent
c444e862
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
371 additions
and
156 deletions
+371
-156
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+1
-0
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+18
-15
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+21
-16
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+2
-9
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+0
-11
nni/retiarii/experiment.py
nni/retiarii/experiment.py
+48
-36
nni/retiarii/integration.py
nni/retiarii/integration.py
+14
-32
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+36
-0
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+7
-4
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+52
-17
nni/retiarii/strategies/__init__.py
nni/retiarii/strategies/__init__.py
+1
-0
nni/retiarii/strategies/random_strategy.py
nni/retiarii/strategies/random_strategy.py
+32
-0
nni/retiarii/strategies/tpe_strategy.py
nni/retiarii/strategies/tpe_strategy.py
+21
-9
nni/retiarii/trainer/pytorch/utils.py
nni/retiarii/trainer/pytorch/utils.py
+3
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+6
-1
test/retiarii_test/darts/darts_model.py
test/retiarii_test/darts/darts_model.py
+1
-1
test/retiarii_test/darts/test.py
test/retiarii_test/darts/test.py
+3
-2
test/retiarii_test/darts/test_oneshot.py
test/retiarii_test/darts/test_oneshot.py
+104
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
08af7771
...
@@ -432,6 +432,7 @@ def _handle_layerchoice(module):
...
@@ -432,6 +432,7 @@ def _handle_layerchoice(module):
def
_handle_inputchoice
(
module
):
def
_handle_inputchoice
(
module
):
m_attrs
=
{}
m_attrs
=
{}
m_attrs
[
'n_candidates'
]
=
module
.
n_candidates
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
m_attrs
[
'reduction'
]
=
module
.
reduction
m_attrs
[
'reduction'
]
=
module
.
reduction
m_attrs
[
'label'
]
=
module
.
label
m_attrs
[
'label'
]
=
module
.
label
...
...
nni/retiarii/execution/api.py
View file @
08af7771
import
time
import
time
import
os
from
typing
import
List
from
..graph
import
Model
,
ModelStatus
from
..graph
import
Model
,
ModelStatus
from
.base
import
BaseExecutionEngine
from
.interface
import
AbstractExecutionEngine
from
.cgo_engine
import
CGOExecutionEngine
from
.interface
import
AbstractExecutionEngine
,
WorkerInfo
from
.listener
import
DefaultListener
from
.listener
import
DefaultListener
_execution_engine
=
None
_execution_engine
=
None
_default_listener
=
None
_default_listener
=
None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
]
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
def
set_execution_engine
(
engine
)
->
None
:
global
_execution_engine
if
_execution_engine
is
None
:
_execution_engine
=
engine
else
:
raise
RuntimeError
(
'execution engine is already set'
)
def
get_execution_engine
()
->
Base
ExecutionEngine
:
def
get_execution_engine
()
->
Abstract
ExecutionEngine
:
"""
"""
Currently we assume the default execution engine is BaseExecutionEngine.
Currently we assume the default execution engine is BaseExecutionEngine.
"""
"""
global
_execution_engine
global
_execution_engine
if
_execution_engine
is
None
:
if
os
.
environ
.
get
(
'CGO'
)
==
'true'
:
_execution_engine
=
CGOExecutionEngine
()
else
:
_execution_engine
=
BaseExecutionEngine
()
return
_execution_engine
return
_execution_engine
...
@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
...
@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
break
break
def
query_available_resources
()
->
List
[
WorkerInfo
]:
def
query_available_resources
()
->
int
:
listener
=
get_and_register_default_listener
(
get_execution_engine
())
engine
=
get_execution_engine
()
return
listener
.
resources
resources
=
engine
.
query_available_resource
()
return
resources
if
isinstance
(
resources
,
int
)
else
len
(
resources
)
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
nni/retiarii/execution/base.py
View file @
08af7771
import
logging
import
logging
import
os
import
random
import
string
from
typing
import
Dict
,
Any
,
List
from
typing
import
Dict
,
Any
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
..
import
codegen
,
utils
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..integration
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration
_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -29,7 +32,7 @@ class BaseGraphData:
...
@@ -29,7 +32,7 @@ class BaseGraphData:
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
"""
"""
The execution engine with no optimization at all.
The execution engine with no optimization at all.
Resource management is
yet to be
implemented.
Resource management is implemented
in this class
.
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
...
@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
resources
=
0
def
submit_models
(
self
,
*
models
:
Model
)
->
None
:
def
submit_models
(
self
,
*
models
:
Model
)
->
None
:
for
model
in
models
:
for
model
in
models
:
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
...
@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self
.
_listeners
.
append
(
listener
)
self
.
_listeners
.
append
(
listener
)
def
_send_trial_callback
(
self
,
paramater
:
dict
)
->
None
:
def
_send_trial_callback
(
self
,
paramater
:
dict
)
->
None
:
for
listener
in
self
.
_listeners
:
if
self
.
resources
<=
0
:
_logger
.
warning
(
'resources: %s'
,
listener
.
resources
)
if
not
listener
.
has_available_resource
():
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
listener
.
on_resource_used
(
1
)
self
.
resources
-=
1
_logger
.
warning
(
'on_resource_used: %
s
'
,
listener
.
resources
)
_logger
.
info
(
'on_resource_used: %
d
'
,
self
.
resources
)
def
_request_trial_jobs_callback
(
self
,
num_trials
:
int
)
->
None
:
def
_request_trial_jobs_callback
(
self
,
num_trials
:
int
)
->
None
:
for
listener
in
self
.
_listeners
:
self
.
resources
+=
num_trials
listener
.
on_resource_available
(
1
*
num_trials
)
_logger
.
info
(
'on_resource_available: %d'
,
self
.
resources
)
_logger
.
warning
(
'on_resource_available: %s'
,
listener
.
resources
)
def
_trial_end_callback
(
self
,
trial_id
:
int
,
success
:
bool
)
->
None
:
def
_trial_end_callback
(
self
,
trial_id
:
int
,
success
:
bool
)
->
None
:
model
=
self
.
_running_models
[
trial_id
]
model
=
self
.
_running_models
[
trial_id
]
...
@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for
listener
in
self
.
_listeners
:
for
listener
in
self
.
_listeners
:
listener
.
on_metric
(
model
,
metrics
)
listener
.
on_metric
(
model
,
metrics
)
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]
:
def
query_available_resource
(
self
)
->
int
:
r
aise
NotImplementedError
# move the method from listener to here?
r
eturn
self
.
resources
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
...
@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer.
Initialize the model, hand it over to trainer.
"""
"""
graph_data
=
BaseGraphData
.
load
(
receive_trial_parameters
())
graph_data
=
BaseGraphData
.
load
(
receive_trial_parameters
())
with
open
(
'_generated_model.py'
,
'w'
)
as
f
:
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
file_name
=
f
'_generated_model_
{
random_str
}
.py'
with
open
(
file_name
,
'w'
)
as
f
:
f
.
write
(
graph_data
.
model_script
)
f
.
write
(
graph_data
.
model_script
)
trainer_cls
=
utils
.
import_
(
graph_data
.
training_module
)
trainer_cls
=
utils
.
import_
(
graph_data
.
training_module
)
model_cls
=
utils
.
import_
(
'_generated_model._model'
)
model_cls
=
utils
.
import_
(
f
'_generated_model
_
{
random_str
}
._model'
)
trainer_instance
=
trainer_cls
(
model
=
model_cls
(),
**
graph_data
.
training_kwargs
)
trainer_instance
=
trainer_cls
(
model
=
model_cls
(),
**
graph_data
.
training_kwargs
)
trainer_instance
.
fit
()
trainer_instance
.
fit
()
os
.
remove
(
file_name
)
\ No newline at end of file
nni/retiarii/execution/cgo_engine.py
View file @
08af7771
...
@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
...
@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..integration
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration
_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
.logical_optimizer.logical_plan
import
LogicalPlan
,
PhysicalDevice
from
.logical_optimizer.logical_plan
import
LogicalPlan
,
PhysicalDevice
from
.logical_optimizer.opt_dedup_input
import
DedupInputOptimizer
from
.logical_optimizer.opt_dedup_input
import
DedupInputOptimizer
...
...
nni/retiarii/execution/interface.py
View file @
08af7771
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
NewType
,
List
from
typing
import
Any
,
NewType
,
List
,
Union
from
..graph
import
Model
,
MetricData
from
..graph
import
Model
,
MetricData
...
@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
...
@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
"""
"""
pass
pass
@
abstractmethod
def
on_resource_available
(
self
,
resources
:
List
[
WorkerInfo
])
->
None
:
"""
Reports when a worker becomes idle.
"""
pass
class
AbstractExecutionEngine
(
ABC
):
class
AbstractExecutionEngine
(
ABC
):
"""
"""
...
@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
...
@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
]
,
int
]
:
"""
"""
Returns information of all idle workers.
Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
...
...
nni/retiarii/execution/listener.py
View file @
08af7771
...
@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener
...
@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener
class
DefaultListener
(
AbstractGraphListener
):
class
DefaultListener
(
AbstractGraphListener
):
def
__init__
(
self
):
self
.
resources
:
int
=
0
# simply resource count
def
has_available_resource
(
self
)
->
bool
:
return
self
.
resources
>
0
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
metric
=
metric
model
.
metric
=
metric
...
@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
...
@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
model
.
status
=
ModelStatus
.
Trained
model
.
status
=
ModelStatus
.
Trained
else
:
else
:
model
.
status
=
ModelStatus
.
Failed
model
.
status
=
ModelStatus
.
Failed
def
on_resource_available
(
self
,
resources
:
int
)
->
None
:
self
.
resources
+=
resources
def
on_resource_used
(
self
,
resources
:
int
)
->
None
:
self
.
resources
-=
resources
nni/retiarii/experiment.py
View file @
08af7771
import
atexit
import
logging
import
logging
import
time
import
socket
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -7,10 +8,14 @@ from subprocess import Popen
...
@@ -7,10 +8,14 @@ from subprocess import Popen
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
from
..experiment
import
Experiment
,
TrainingServiceConfig
,
launcher
,
rest
import
colorama
import
psutil
from
..experiment
import
Experiment
,
TrainingServiceConfig
,
launcher
from
..experiment.config.base
import
ConfigBase
,
PathLike
from
..experiment.config.base
import
ConfigBase
,
PathLike
from
..experiment.config
import
util
from
..experiment.config
import
util
from
..experiment.pipe
import
Pipe
from
..experiment.pipe
import
Pipe
from
.graph
import
Model
from
.graph
import
Model
from
.utils
import
get_records
from
.utils
import
get_records
from
.integration
import
RetiariiAdvisor
from
.integration
import
RetiariiAdvisor
...
@@ -18,9 +23,11 @@ from .converter import convert_to_graph
...
@@ -18,9 +23,11 @@ from .converter import convert_to_graph
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.trainer.interface
import
BaseTrainer
from
.trainer.interface
import
BaseTrainer
from
.strategies.strategy
import
BaseStrategy
from
.strategies.strategy
import
BaseStrategy
from
.trainer.pytorch
import
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePathTrainer
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
OneShotTrainers
=
(
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePathTrainer
)
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ConfigBase
):
class
RetiariiExeConfig
(
ConfigBase
):
...
@@ -76,7 +83,7 @@ _validation_rules = {
...
@@ -76,7 +83,7 @@ _validation_rules = {
class
RetiariiExperiment
(
Experiment
):
class
RetiariiExperiment
(
Experiment
):
def
__init__
(
self
,
base_model
:
Model
,
trainer
:
BaseTrainer
,
def
__init__
(
self
,
base_model
:
Model
,
trainer
:
BaseTrainer
,
applied_mutators
:
Mutator
,
strategy
:
BaseStrategy
):
applied_mutators
:
Mutator
=
None
,
strategy
:
BaseStrategy
=
None
):
self
.
config
:
RetiariiExeConfig
=
None
self
.
config
:
RetiariiExeConfig
=
None
self
.
port
:
Optional
[
int
]
=
None
self
.
port
:
Optional
[
int
]
=
None
...
@@ -87,6 +94,7 @@ class RetiariiExperiment(Experiment):
...
@@ -87,6 +94,7 @@ class RetiariiExperiment(Experiment):
self
.
recorded_module_args
=
get_records
()
self
.
recorded_module_args
=
get_records
()
self
.
_dispatcher
=
RetiariiAdvisor
()
self
.
_dispatcher
=
RetiariiAdvisor
()
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
...
@@ -103,7 +111,10 @@ class RetiariiExperiment(Experiment):
...
@@ -103,7 +111,10 @@ class RetiariiExperiment(Experiment):
mutator
=
LayerChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'choices'
])
mutator
=
LayerChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'choices'
])
applied_mutators
.
append
(
mutator
)
applied_mutators
.
append
(
mutator
)
for
node
in
ic_nodes
:
for
node
in
ic_nodes
:
mutator
=
InputChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'n_chosen'
])
mutator
=
InputChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'n_candidates'
],
node
.
operation
.
parameters
[
'n_chosen'
],
node
.
operation
.
parameters
[
'reduction'
])
applied_mutators
.
append
(
mutator
)
applied_mutators
.
append
(
mutator
)
return
applied_mutators
return
applied_mutators
...
@@ -132,7 +143,7 @@ class RetiariiExperiment(Experiment):
...
@@ -132,7 +143,7 @@ class RetiariiExperiment(Experiment):
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model
,
self
.
applied_mutators
)).
start
()
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model
,
self
.
applied_mutators
)).
start
()
_logger
.
info
(
'Strategy started!'
)
_logger
.
info
(
'Strategy started!'
)
def
start
(
self
,
config
:
RetiariiExeConfig
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
Start the experiment in background.
Start the experiment in background.
This method will raise exception on failure.
This method will raise exception on failure.
...
@@ -144,11 +155,12 @@ class RetiariiExperiment(Experiment):
...
@@ -144,11 +155,12 @@ class RetiariiExperiment(Experiment):
debug
debug
Whether to start in debug mode.
Whether to start in debug mode.
"""
"""
# FIXME:
atexit
.
register
(
self
.
stop
)
if
debug
:
if
debug
:
logging
.
getLogger
(
'nni'
).
setLevel
(
logging
.
DEBUG
)
logging
.
getLogger
(
'nni'
).
setLevel
(
logging
.
DEBUG
)
self
.
_proc
,
self
.
_pipe
=
launcher
.
start_experiment
(
config
,
port
,
debug
)
self
.
_proc
,
self
.
_pipe
=
launcher
.
start_experiment
(
self
.
config
,
port
,
debug
)
assert
self
.
_proc
is
not
None
assert
self
.
_proc
is
not
None
assert
self
.
_pipe
is
not
None
assert
self
.
_pipe
is
not
None
...
@@ -156,42 +168,42 @@ class RetiariiExperiment(Experiment):
...
@@ -156,42 +168,42 @@ class RetiariiExperiment(Experiment):
# dispatcher must be created after pipe initialized
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread
(
target
=
self
.
_dispatcher
.
run
).
start
()
self
.
_dispatcher_thread
=
Thread
(
target
=
self
.
_dispatcher
.
run
)
self
.
_dispatcher_thread
.
start
()
self
.
_start_strategy
()
self
.
_start_strategy
()
ips
=
[
self
.
config
.
nni_manager_ip
]
for
interfaces
in
psutil
.
net_if_addrs
().
values
():
for
interface
in
interfaces
:
if
interface
.
family
==
socket
.
AF_INET
:
ips
.
append
(
interface
.
address
)
ips
=
[
f
'http://
{
ip
}
:
{
port
}
'
for
ip
in
ips
if
ip
]
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
_logger
.
info
(
msg
)
# TODO: register experiment management metadata
# TODO: register experiment management metadata
def
stop
(
self
)
->
None
:
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
"""
"""
Stop background experiment.
Run the experiment.
This function will block until experiment finish or error.
"""
"""
self
.
_proc
.
kill
()
if
isinstance
(
self
.
trainer
,
OneShotTrainers
):
self
.
_pipe
.
close
()
self
.
trainer
.
fit
()
else
:
assert
config
is
not
None
,
'You are using classic search mode, config cannot be None!'
self
.
config
=
config
super
().
run
(
port
,
debug
)
self
.
port
=
None
def
export_top_models
(
self
,
top_n
:
int
):
self
.
_proc
=
None
"""
self
.
_pipe
=
None
export several top performing models
"""
raise
NotImplementedError
def
r
un
(
self
,
config
:
RetiariiExeConfig
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
def
r
etrain_model
(
self
,
model
)
:
"""
"""
Run the experiment.
this function retrains the exported model, and test it to output test accuracy
This function will block until experiment finish or error.
"""
"""
self
.
config
=
config
raise
NotImplementedError
self
.
start
(
config
,
port
,
debug
)
try
:
while
True
:
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# TODO: double check the status
if
status
in
[
'ERROR'
,
'STOPPED'
,
'NO_MORE_TRIAL'
]:
return
status
finally
:
self
.
stop
()
def
get_status
(
self
)
->
str
:
if
self
.
port
is
None
:
raise
RuntimeError
(
'Experiment is not running'
)
resp
=
rest
.
get
(
self
.
port
,
'/check-status'
)
return
resp
[
'status'
]
nni/retiarii/integration.py
View file @
08af7771
import
logging
import
logging
import
os
from
typing
import
Any
,
Callable
from
typing
import
Any
,
Callable
import
json_tricks
import
json_tricks
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.graph
import
MetricData
from
.execution.base
import
BaseExecutionEngine
from
.execution.cgo_engine
import
CGOExecutionEngine
from
.execution.api
import
set_execution_engine
from
.integration_api
import
register_advisor
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
parameters_count
=
0
self
.
parameters_count
=
0
engine
=
self
.
_create_execution_engine
()
set_execution_engine
(
engine
)
def
_create_execution_engine
(
self
):
if
os
.
environ
.
get
(
'CGO'
)
==
'true'
:
return
CGOExecutionEngine
()
else
:
return
BaseExecutionEngine
()
def
handle_initialize
(
self
,
data
):
def
handle_initialize
(
self
,
data
):
"""callback for initializing the advisor
"""callback for initializing the advisor
Parameters
Parameters
...
@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
else
:
else
:
return
value
return
value
return
value
return
value
_advisor
:
RetiariiAdvisor
=
None
def
get_advisor
()
->
RetiariiAdvisor
:
global
_advisor
assert
_advisor
is
not
None
return
_advisor
def
register_advisor
(
advisor
:
RetiariiAdvisor
):
global
_advisor
assert
_advisor
is
None
_advisor
=
advisor
def
send_trial
(
parameters
:
dict
)
->
int
:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return
get_advisor
().
send_trial
(
parameters
)
def
receive_trial_parameters
()
->
dict
:
"""
Received a new trial. Executed on trial end.
"""
params
=
nni
.
get_next_parameter
()
return
params
nni/retiarii/integration_api.py
0 → 100644
View file @
08af7771
from
typing
import
NewType
,
Any
import
nni
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
_advisor
:
'RetiariiAdvisor'
=
None
def
get_advisor
()
->
'RetiariiAdvisor'
:
global
_advisor
assert
_advisor
is
not
None
return
_advisor
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
global
_advisor
assert
_advisor
is
None
_advisor
=
advisor
def
send_trial
(
parameters
:
dict
)
->
int
:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return
get_advisor
().
send_trial
(
parameters
)
def
receive_trial_parameters
()
->
dict
:
"""
Received a new trial. Executed on trial end.
"""
params
=
nni
.
get_next_parameter
()
return
params
nni/retiarii/mutator.py
View file @
08af7771
...
@@ -104,6 +104,7 @@ class _RecorderSampler(Sampler):
...
@@ -104,6 +104,7 @@ class _RecorderSampler(Sampler):
self
.
recorded_candidates
.
append
(
candidates
)
self
.
recorded_candidates
.
append
(
candidates
)
return
candidates
[
0
]
return
candidates
[
0
]
# the following is for inline mutation
# the following is for inline mutation
...
@@ -122,14 +123,16 @@ class LayerChoiceMutator(Mutator):
...
@@ -122,14 +123,16 @@ class LayerChoiceMutator(Mutator):
class
InputChoiceMutator
(
Mutator
):
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
n_c
hosen
:
int
):
def
__init__
(
self
,
node_name
:
str
,
n_c
andidates
:
int
,
n_chosen
:
int
,
reduction
:
str
):
super
().
__init__
()
super
().
__init__
()
self
.
node_name
=
node_name
self
.
node_name
=
node_name
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
def
mutate
(
self
,
model
):
def
mutate
(
self
,
model
):
target
=
model
.
get_node_by_name
(
self
.
node_name
)
target
=
model
.
get_node_by_name
(
self
.
node_name
)
candidates
=
[
i
for
i
in
range
(
self
.
n_c
hosen
)]
candidates
=
[
i
for
i
in
range
(
self
.
n_c
andidates
)]
chosen
=
self
.
choice
(
candidates
)
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
self
.
n_chosen
)]
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs'
,
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs'
,
{
'chosen'
:
chosen
})
{
'chosen'
:
chosen
,
'reduction'
:
self
.
reduction
})
nni/retiarii/nn/pytorch/nn.py
View file @
08af7771
...
@@ -5,10 +5,12 @@ from typing import Any, List
...
@@ -5,10 +5,12 @@ from typing import Any, List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
...utils
import
add_record
from
...utils
import
add_record
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
# NOTE: support pytorch version >= 1.5.0
__all__
=
[
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
...
@@ -29,18 +31,27 @@ __all__ = [
...
@@ -29,18 +31,27 @@ __all__ = [
'ConstantPad3d'
,
'Bilinear'
,
'CosineSimilarity'
,
'Unfold'
,
'Fold'
,
'ConstantPad3d'
,
'Bilinear'
,
'CosineSimilarity'
,
'Unfold'
,
'Fold'
,
'AdaptiveLogSoftmaxWithLoss'
,
'TransformerEncoder'
,
'TransformerDecoder'
,
'AdaptiveLogSoftmaxWithLoss'
,
'TransformerEncoder'
,
'TransformerDecoder'
,
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
'Flatten'
,
'Hardsigmoid'
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
]
]
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
__all__
.
append
(
'Hardswish'
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
__all__
.
extend
([
'Unflatten'
,
'SiLU'
,
'TripletMarginWithDistanceLoss'
])
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'ChannelShuffle'
class
LayerChoice
(
nn
.
Module
):
class
LayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
super
(
LayerChoice
,
self
).
__init__
()
super
(
LayerChoice
,
self
).
__init__
()
self
.
candidate_ops
=
op_candidates
self
.
candidate_ops
=
op_candidates
self
.
label
=
key
self
.
label
=
key
self
.
key
=
key
# deprecated, for backward compatibility
for
i
,
module
in
enumerate
(
op_candidates
):
# deprecated, for backward compatibility
self
.
add_module
(
str
(
i
),
module
)
if
reduction
or
return_mask
:
if
reduction
or
return_mask
:
_logger
.
warning
(
'input arguments `reduction` and `return_mask` are deprecated!'
)
_logger
.
warning
(
'input arguments `reduction` and `return_mask` are deprecated!'
)
...
@@ -52,10 +63,12 @@ class InputChoice(nn.Module):
...
@@ -52,10 +63,12 @@ class InputChoice(nn.Module):
def
__init__
(
self
,
n_candidates
=
None
,
choose_from
=
None
,
n_chosen
=
1
,
def
__init__
(
self
,
n_candidates
=
None
,
choose_from
=
None
,
n_chosen
=
1
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
super
(
InputChoice
,
self
).
__init__
()
super
(
InputChoice
,
self
).
__init__
()
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
label
=
key
self
.
label
=
key
if
n_candidates
or
choose_from
or
return_mask
:
self
.
key
=
key
# deprecated, for backward compatibility
if
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
@@ -86,13 +99,31 @@ class Placeholder(nn.Module):
...
@@ -86,13 +99,31 @@ class Placeholder(nn.Module):
class
ChosenInputs
(
nn
.
Module
):
class
ChosenInputs
(
nn
.
Module
):
def
__init__
(
self
,
chosen
:
int
):
"""
"""
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
super
().
__init__
()
super
().
__init__
()
self
.
chosen
=
chosen
self
.
chosen
=
chosen
self
.
reduction
=
reduction
def
forward
(
self
,
candidate_inputs
):
def
forward
(
self
,
candidate_inputs
):
# TODO: support multiple chosen inputs
return
self
.
_tensor_reduction
(
self
.
reduction
,
[
candidate_inputs
[
i
]
for
i
in
self
.
chosen
])
return
candidate_inputs
[
self
.
chosen
]
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
return
tensor_list
if
not
tensor_list
:
return
None
# empty. return None for now
if
len
(
tensor_list
)
==
1
:
return
tensor_list
[
0
]
if
reduction_type
==
"sum"
:
return
sum
(
tensor_list
)
if
reduction_type
==
"mean"
:
return
sum
(
tensor_list
)
/
len
(
tensor_list
)
if
reduction_type
==
"concat"
:
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
raise
ValueError
(
"Unrecognized reduction policy:
\"
{}
\"
"
.
format
(
reduction_type
))
# the following are pytorch modules
# the following are pytorch modules
...
@@ -132,7 +163,6 @@ def wrap_module(original_class):
...
@@ -132,7 +163,6 @@ def wrap_module(original_class):
return
original_class
return
original_class
# TODO: support different versions of pytorch
Identity
=
wrap_module
(
nn
.
Identity
)
Identity
=
wrap_module
(
nn
.
Identity
)
Linear
=
wrap_module
(
nn
.
Linear
)
Linear
=
wrap_module
(
nn
.
Linear
)
Conv1d
=
wrap_module
(
nn
.
Conv1d
)
Conv1d
=
wrap_module
(
nn
.
Conv1d
)
...
@@ -236,6 +266,17 @@ TransformerDecoder = wrap_module(nn.TransformerDecoder)
...
@@ -236,6 +266,17 @@ TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer
=
wrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerEncoderLayer
=
wrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
wrap_module
(
nn
.
TransformerDecoderLayer
)
TransformerDecoderLayer
=
wrap_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
wrap_module
(
nn
.
Transformer
)
Transformer
=
wrap_module
(
nn
.
Transformer
)
Flatten
=
wrap_module
(
nn
.
Flatten
)
Hardsigmoid
=
wrap_module
(
nn
.
Hardsigmoid
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
wrap_module
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
wrap_module
(
nn
.
SiLU
)
Unflatten
=
wrap_module
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
wrap_module
(
nn
.
TripletMarginWithDistanceLoss
)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
...
@@ -243,10 +284,4 @@ Transformer = wrap_module(nn.Transformer)
...
@@ -243,10 +284,4 @@ Transformer = wrap_module(nn.Transformer)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten
=
wrap_module
(
nn
.
Flatten
)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid
=
wrap_module
(
nn
.
Hardsigmoid
)
Hardswish
=
wrap_module
(
nn
.
Hardswish
)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
nni/retiarii/strategies/__init__.py
View file @
08af7771
from
.tpe_strategy
import
TPEStrategy
from
.tpe_strategy
import
TPEStrategy
from
.random_strategy
import
RandomStrategy
nni/retiarii/strategies/random_strategy.py
0 → 100644
View file @
08af7771
import
logging
import
random
import
time
from
..
import
Sampler
,
submit_models
,
query_available_resources
from
.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
class RandomSampler(Sampler):
    """Sampler that draws every mutation choice uniformly at random."""

    def choice(self, candidates, mutator, model, index):
        # The draw ignores mutator/model/index: each decision is independent.
        return random.choice(candidates)
class RandomStrategy(BaseStrategy):
    """
    Exploration strategy that repeatedly samples a random architecture from
    the base model's mutation space and submits it for training whenever a
    training resource is available.
    """

    def __init__(self):
        self.random_sampler = RandomSampler()

    def run(self, base_model, applied_mutators):
        # Fixed typo in the log message ('stargety' -> 'strategy').
        _logger.info('strategy start...')
        while True:
            avail_resource = query_available_resources()
            if avail_resource > 0:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: %s', str(applied_mutators))
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.random_sampler)
                    model = mutator.apply(model)
                # run models
                submit_models(model)
            else:
                # No free resources: back off briefly before polling again.
                time.sleep(2)
nni/retiarii/strategies/tpe_strategy.py
View file @
08af7771
import
logging
import
logging
import
time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
wait_models
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
.strategy
import
BaseStrategy
from
.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
    def __init__(self):
        # Sampler that turns TPE suggestions into mutation choices.
        self.tpe_sampler = TPESampler()
        # Monotonically increasing id assigned to each submitted model.
        self.model_id = 0
        # Maps model id -> submitted model whose result has not yet been
        # reported back to the TPE sampler.
        self.running_models = {}
def
run
(
self
,
base_model
,
applied_mutators
):
def
run
(
self
,
base_model
,
applied_mutators
):
sample_space
=
[]
sample_space
=
[]
...
@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
...
@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
sample_space
.
extend
(
recorded_candidates
)
sample_space
.
extend
(
recorded_candidates
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
try
:
_logger
.
info
(
'stargety start...'
)
_logger
.
info
(
'stargety start...'
)
while
True
:
while
True
:
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators: %s'
,
str
(
applied_mutators
))
_logger
.
info
(
'mutators: %s'
,
str
(
applied_mutators
))
...
@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
...
@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
model
=
mutator
.
apply
(
model
)
model
=
mutator
.
apply
(
model
)
# run models
# run models
submit_models
(
model
)
submit_models
(
model
)
wait_models
(
model
)
self
.
running_models
[
self
.
model_id
]
=
model
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
model_id
+=
1
self
.
model_id
+=
1
_logger
.
info
(
'Strategy says: %s'
,
model
.
metric
)
else
:
except
Exception
:
time
.
sleep
(
2
)
_logger
.
error
(
logging
.
exception
(
'message'
))
_logger
.
warning
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warning
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
nni/retiarii/trainer/pytorch/utils.py
View file @
08af7771
...
@@ -6,6 +6,7 @@ from collections import OrderedDict
...
@@ -6,6 +6,7 @@ from collections import OrderedDict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
nni.nas.pytorch.mutables
import
InputChoice
,
LayerChoice
from
nni.nas.pytorch.mutables
import
InputChoice
,
LayerChoice
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
...
@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]]
List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
A list from layer choice keys (names) and replaced modules.
"""
"""
return
_replace_module_with_type
(
root_module
,
init_fn
,
LayerChoice
,
modules
)
return
_replace_module_with_type
(
root_module
,
init_fn
,
(
LayerChoice
,
nn
.
LayerChoice
),
modules
)
def
replace_input_choice
(
root_module
,
init_fn
,
modules
=
None
):
def
replace_input_choice
(
root_module
,
init_fn
,
modules
=
None
):
...
@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
...
@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]]
List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
A list from layer choice keys (names) and replaced modules.
"""
"""
return
_replace_module_with_type
(
root_module
,
init_fn
,
InputChoice
,
modules
)
return
_replace_module_with_type
(
root_module
,
init_fn
,
(
InputChoice
,
nn
.
InputChoice
),
modules
)
nni/retiarii/utils.py
View file @
08af7771
...
@@ -10,6 +10,11 @@ def import_(target: str, allow_none: bool = False) -> Any:
...
@@ -10,6 +10,11 @@ def import_(target: str, allow_none: bool = False) -> Any:
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
return
getattr
(
module
,
identifier
)
return
getattr
(
module
,
identifier
)
def version_larger_equal(a: str, b: str) -> bool:
    """Return True iff version string ``a`` is >= version string ``b``.

    Local build suffixes (e.g. ``1.6.0+cu101``) are stripped before
    comparing. Versions with a different number of components are padded
    with zeros so that e.g. ``'1.6'`` compares equal to ``'1.6.0'`` —
    plain tuple comparison would wrongly report ``(1, 6) < (1, 6, 0)``.

    Raises ValueError if any component is not a plain integer.
    """
    # TODO: refactor later
    a = a.split('+')[0]
    b = b.split('+')[0]
    version_a = [int(part) for part in a.split('.')]
    version_b = [int(part) for part in b.split('.')]
    width = max(len(version_a), len(version_b))
    version_a += [0] * (width - len(version_a))
    version_b += [0] * (width - len(version_b))
    return version_a >= version_b
_records
=
{}
_records
=
{}
...
@@ -24,7 +29,7 @@ def add_record(key, value):
...
@@ -24,7 +29,7 @@ def add_record(key, value):
"""
"""
global
_records
global
_records
if
_records
is
not
None
:
if
_records
is
not
None
:
assert
key
not
in
_records
,
'{} already in _records'
.
format
(
key
)
#
assert key not in _records, '{} already in _records'.format(key)
_records
[
key
]
=
value
_records
[
key
]
=
value
...
...
test/retiarii_test/darts/darts_model.py
View file @
08af7771
...
@@ -55,7 +55,7 @@ class Node(nn.Module):
...
@@ -55,7 +55,7 @@ class Node(nn.Module):
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
)
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
)
]))
]))
self
.
drop_path
=
ops
.
DropPath
()
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
nn
.
InputChoice
(
n_chosen
=
2
)
self
.
input_switch
=
nn
.
InputChoice
(
n_candidates
=
num_prev_nodes
,
n_chosen
=
2
)
def
forward
(
self
,
prev_nodes
:
List
[
'Tensor'
])
->
'Tensor'
:
def
forward
(
self
,
prev_nodes
:
List
[
'Tensor'
])
->
'Tensor'
:
#assert self.ops.__len__() == len(prev_nodes)
#assert self.ops.__len__() == len(prev_nodes)
...
...
test/retiarii_test/darts/test.py
View file @
08af7771
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
from
pathlib
import
Path
from
pathlib
import
Path
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.strategies
import
TPEStrategy
,
RandomStrategy
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
darts_model
import
CNN
from
darts_model
import
CNN
...
@@ -18,7 +18,8 @@ if __name__ == '__main__':
...
@@ -18,7 +18,8 @@ if __name__ == '__main__':
optimizer_kwargs
=
{
"lr"
:
1e-3
},
optimizer_kwargs
=
{
"lr"
:
1e-3
},
trainer_kwargs
=
{
"max_epochs"
:
1
})
trainer_kwargs
=
{
"max_epochs"
:
1
})
simple_startegy
=
TPEStrategy
()
#simple_startegy = TPEStrategy()
simple_startegy
=
RandomStrategy
()
exp
=
RetiariiExperiment
(
base_model
,
trainer
,
[],
simple_startegy
)
exp
=
RetiariiExperiment
(
base_model
,
trainer
,
[],
simple_startegy
)
...
...
test/retiarii_test/darts/test_oneshot.py
0 → 100644
View file @
08af7771
import
json
import
numpy
as
np
import
os
import
sys
import
torch
import
torch.nn
as
nn
from
pathlib
import
Path
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.trainer.pytorch
import
DartsTrainer
from
darts_model
import
CNN
class Cutout(object):
    """
    Cutout augmentation: zero out (in place) a ``length`` x ``length``
    square patch of the image tensor, centered at a uniformly random
    position and clipped at the image borders.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        mask = np.ones((height, width), np.float32)
        # Draw the patch center; the sampling order (y then x) matters for
        # reproducibility under a fixed numpy seed.
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)
        half = self.length // 2
        top = np.clip(center_y - half, 0, height)
        bottom = np.clip(center_y + half, 0, height)
        left = np.clip(center_x - half, 0, width)
        right = np.clip(center_x + half, 0, width)
        mask[top:bottom, left:right] = 0.
        cutout_mask = torch.from_numpy(mask).expand_as(img)
        # In-place multiply: the caller's tensor is modified and returned.
        img *= cutout_mask
        return img
def get_dataset(cls, cutout_length=0):
    """Build the (train, valid) CIFAR-10 dataset pair.

    The training pipeline uses random crop + horizontal flip, then
    normalization, and optionally Cutout when ``cutout_length`` > 0; the
    validation pipeline only normalizes. Only ``cls == "cifar10"`` is
    supported; anything else raises NotImplementedError.
    """
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    cutout = [Cutout(cutout_length)] if cutout_length > 0 else []

    train_transform = transforms.Compose(augmentation + normalize + cutout)
    valid_transform = transforms.Compose(normalize)

    if cls != "cifar10":
        raise NotImplementedError
    dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    return dataset_train, dataset_valid
def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k.

    Parameters: ``output`` is a (batch, classes) score tensor, ``target``
    is either a (batch,) class-index tensor or a (batch, classes) one-hot
    tensor. Returns a dict mapping "acc{k}" -> float accuracy for each k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        # reshape instead of view: correct[:k] can be non-contiguous because
        # pred was transposed above, and .view() raises on such tensors in
        # modern PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
# Entry point: one-shot DARTS search on CIFAR-10 through the Retiarii API.
if __name__ == '__main__':
    # NOTE(review): CNN argument meanings (input size 32, 3 input channels,
    # 16 initial channels, 10 classes, 8 layers) inferred from typical DARTS
    # model signatures — confirm against darts_model.CNN.
    base_model = CNN(32, 3, 16, 10, 8)
    dataset_train, dataset_valid = get_dataset("cifar10")
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(base_model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    # NOTE(review): lr_scheduler is created but never passed to the trainer
    # below — confirm whether DartsTrainer is expected to receive it.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50, eta_min=0.001)

    trainer = DartsTrainer(
        model=base_model,
        loss=criterion,
        # Report top-1 accuracy as the training metric.
        metrics=lambda output, target: accuracy(output, target, topk=(1,)),
        optimizer=optim,
        num_epochs=50,
        dataset=dataset_train,
        batch_size=32,
        log_frequency=10,
        unrolled=False
    )
    # One-shot trainer: no explicit strategy/mutators are passed.
    exp = RetiariiExperiment(base_model, trainer)
    exp.run()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment