Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
063d6b74
Unverified
Commit
063d6b74
authored
Apr 26, 2021
by
SparkSnail
Committed by
GitHub
Apr 26, 2021
Browse files
Merge pull request #3580 from microsoft/v2.2
[do not Squash!] Merge V2.2 back to master
parents
08986c6b
e1295888
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
146 additions
and
130 deletions
+146
-130
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+7
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+15
-5
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-4
nni/runtime/log.py
nni/runtime/log.py
+7
-3
nni/tools/nnictl/config_utils.py
nni/tools/nnictl/config_utils.py
+5
-5
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+89
-22
nni/tools/nnictl/nnictl_utils.py
nni/tools/nnictl/nnictl_utils.py
+0
-35
pipelines/integration-test-adl.yml
pipelines/integration-test-adl.yml
+1
-1
pipelines/integration-test-frameworkcontroller.yml
pipelines/integration-test-frameworkcontroller.yml
+1
-1
pipelines/integration-test-kubeflow.yml
pipelines/integration-test-kubeflow.yml
+1
-1
test/config/examples/classic-nas-pytorch.yml
test/config/examples/classic-nas-pytorch.yml
+1
-1
test/config/examples/classic-nas-tf2.yml
test/config/examples/classic-nas-tf2.yml
+1
-1
test/config/integration_tests.yml
test/config/integration_tests.yml
+2
-22
test/config/integration_tests_tf2.yml
test/config/integration_tests_tf2.yml
+1
-23
test/config/metrics_test/trial.py
test/config/metrics_test/trial.py
+1
-0
test/config/tuners/regularized_evolution_tuner.yml
test/config/tuners/regularized_evolution_tuner.yml
+1
-1
test/scripts/nas.sh
test/scripts/nas.sh
+4
-4
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
ts/nni_manager/common/experimentConfig.ts
ts/nni_manager/common/experimentConfig.ts
+1
-0
ts/nni_manager/common/manager.ts
ts/nni_manager/common/manager.ts
+1
-1
No files found.
nni/retiarii/execution/interface.py
View file @
063d6b74
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
budget_exhausted
(
self
)
->
bool
:
"""
Check whether user configured max trial number or max execution duration has been reached
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
"""
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
063d6b74
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger
.
info
(
'Start strategy...'
)
_logger
.
info
(
'Start strategy...'
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
_logger
.
info
(
'Strategy exit'
)
self
.
_dispatcher
.
mark_experiment_as_ending
()
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
Thread
(
target
=
self
.
_check_exp_status
).
start
()
exp_status_checker
=
Thread
(
target
=
self
.
_check_exp_status
)
exp_status_checker
.
start
()
self
.
_start_strategy
()
self
.
_start_strategy
()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
exp_status_checker
.
join
()
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try
:
try
:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
status
=
self
.
get_status
()
else
:
return
False
if
status
==
'DONE'
or
status
==
'STOPPED'
:
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
return
True
if
status
==
'ERROR'
:
if
status
==
'ERROR'
:
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
if
self
.
_proc
is
not
None
:
try
:
try
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
exception
(
e
)
_logger
.
exception
(
e
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
063d6b74
...
@@ -6,7 +6,7 @@ import time
...
@@ -6,7 +6,7 @@ import time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
,
budget_exhausted
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
_logger
.
info
(
'TPE strategy has been started.'
)
_logger
.
info
(
'TPE strategy has been started.'
)
while
True
:
while
not
budget_exhausted
()
:
avail_resource
=
query_available_resources
()
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else
:
else
:
time
.
sleep
(
2
)
time
.
sleep
(
2
)
_logger
.
warnin
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
_logger
.
debu
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warnin
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
_logger
.
debu
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
del
self
.
running_models
[
_id
]
nni/runtime/log.py
View file @
063d6b74
...
@@ -46,6 +46,7 @@ def init_logger() -> None:
...
@@ -46,6 +46,7 @@ def init_logger() -> None:
logging
.
getLogger
(
'filelock'
).
setLevel
(
logging
.
WARNING
)
logging
.
getLogger
(
'filelock'
).
setLevel
(
logging
.
WARNING
)
_exp_log_initialized
=
False
def
init_logger_experiment
()
->
None
:
def
init_logger_experiment
()
->
None
:
"""
"""
...
@@ -53,9 +54,12 @@ def init_logger_experiment() -> None:
...
@@ -53,9 +54,12 @@ def init_logger_experiment() -> None:
This function will get invoked after `init_logger()`.
This function will get invoked after `init_logger()`.
"""
"""
colorful_formatter
=
Formatter
(
log_format
,
time_format
)
global
_exp_log_initialized
colorful_formatter
.
format
=
_colorful_format
if
not
_exp_log_initialized
:
handlers
[
'_default_'
].
setFormatter
(
colorful_formatter
)
_exp_log_initialized
=
True
colorful_formatter
=
Formatter
(
log_format
,
time_format
)
colorful_formatter
.
format
=
_colorful_format
handlers
[
'_default_'
].
setFormatter
(
colorful_formatter
)
def
start_experiment_log
(
experiment_id
:
str
,
log_directory
:
Path
,
debug
:
bool
)
->
None
:
def
start_experiment_log
(
experiment_id
:
str
,
log_directory
:
Path
,
debug
:
bool
)
->
None
:
log_path
=
_prepare_log_dir
(
log_directory
)
/
'dispatcher.log'
log_path
=
_prepare_log_dir
(
log_directory
)
/
'dispatcher.log'
...
...
nni/tools/nnictl/config_utils.py
View file @
063d6b74
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
os
import
os
import
json
import
json
_tricks
import
shutil
import
shutil
import
sqlite3
import
sqlite3
import
time
import
time
...
@@ -92,7 +92,7 @@ class Config:
...
@@ -92,7 +92,7 @@ class Config:
'''refresh to get latest config'''
'''refresh to get latest config'''
sql
=
'select params from ExperimentProfile where id=? order by revision DESC'
sql
=
'select params from ExperimentProfile where id=? order by revision DESC'
args
=
(
self
.
experiment_id
,)
args
=
(
self
.
experiment_id
,)
self
.
config
=
config_v0_to_v1
(
json
.
loads
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
self
.
config
=
config_v0_to_v1
(
json
_tricks
.
loads
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
def
get_config
(
self
):
def
get_config
(
self
):
'''get a value according to key'''
'''get a value according to key'''
...
@@ -123,7 +123,7 @@ class Experiments:
...
@@ -123,7 +123,7 @@ class Experiments:
self
.
experiments
[
expId
][
'tag'
]
=
tag
self
.
experiments
[
expId
][
'tag'
]
=
tag
self
.
experiments
[
expId
][
'pid'
]
=
pid
self
.
experiments
[
expId
][
'pid'
]
=
pid
self
.
experiments
[
expId
][
'webuiUrl'
]
=
webuiUrl
self
.
experiments
[
expId
][
'webuiUrl'
]
=
webuiUrl
self
.
experiments
[
expId
][
'logDir'
]
=
logDir
self
.
experiments
[
expId
][
'logDir'
]
=
str
(
logDir
)
self
.
write_file
()
self
.
write_file
()
def
update_experiment
(
self
,
expId
,
key
,
value
):
def
update_experiment
(
self
,
expId
,
key
,
value
):
...
@@ -155,7 +155,7 @@ class Experiments:
...
@@ -155,7 +155,7 @@ class Experiments:
'''save config to local file'''
'''save config to local file'''
try
:
try
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
json
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
json
_tricks
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
except
IOError
as
error
:
except
IOError
as
error
:
print
(
'Error:'
,
error
)
print
(
'Error:'
,
error
)
return
''
return
''
...
@@ -165,7 +165,7 @@ class Experiments:
...
@@ -165,7 +165,7 @@ class Experiments:
if
os
.
path
.
exists
(
self
.
experiment_file
):
if
os
.
path
.
exists
(
self
.
experiment_file
):
try
:
try
:
with
open
(
self
.
experiment_file
,
'r'
)
as
file
:
with
open
(
self
.
experiment_file
,
'r'
)
as
file
:
return
json
.
load
(
file
)
return
json
_tricks
.
load
(
file
)
except
ValueError
:
except
ValueError
:
return
{}
return
{}
return
{}
return
{}
nni/tools/nnictl/launcher.py
View file @
063d6b74
...
@@ -119,12 +119,51 @@ def set_trial_config(experiment_config, port, config_file_name):
...
@@ -119,12 +119,51 @@ def set_trial_config(experiment_config, port, config_file_name):
def
set_adl_config
(
experiment_config
,
port
,
config_file_name
):
def
set_adl_config
(
experiment_config
,
port
,
config_file_name
):
'''set adl configuration'''
'''set adl configuration'''
adl_config_data
=
dict
()
# hack for supporting v2 config, need refactor
adl_config_data
[
'adl_config'
]
=
{}
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
adl_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
#set trial_config
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
None
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
None
def
validate_response
(
response
,
config_file_name
):
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
print_error
(
'Error:'
+
err_message
)
exit
(
1
)
# hack to fix v1 version_check and log_collection bug, need refactor
def
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
):
version_check
=
True
#debug mode should disable version check
if
experiment_config
.
get
(
'debug'
)
is
not
None
:
version_check
=
not
experiment_config
.
get
(
'debug'
)
#validate version check
if
experiment_config
.
get
(
'versionCheck'
)
is
not
None
:
version_check
=
experiment_config
.
get
(
'versionCheck'
)
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
({
'version_check'
:
version_check
}),
REST_TIME_OUT
)
validate_response
(
response
,
config_file_name
)
if
experiment_config
.
get
(
'logCollection'
):
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
({
'log_collection'
:
experiment_config
.
get
(
'logCollection'
)}),
REST_TIME_OUT
)
validate_response
(
response
,
config_file_name
)
def
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
):
def
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
):
'''set nniManagerIp'''
'''set nniManagerIp'''
if
experiment_config
.
get
(
'nniManagerIp'
)
is
None
:
if
experiment_config
.
get
(
'nniManagerIp'
)
is
None
:
...
@@ -155,6 +194,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
...
@@ -155,6 +194,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
...
@@ -174,6 +214,7 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
...
@@ -174,6 +214,7 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
...
@@ -200,9 +241,13 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name):
...
@@ -200,9 +241,13 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name):
request_data
[
'experimentName'
]
=
experiment_config
[
'experimentName'
]
request_data
[
'experimentName'
]
=
experiment_config
[
'experimentName'
]
request_data
[
'trialConcurrency'
]
=
experiment_config
[
'trialConcurrency'
]
request_data
[
'trialConcurrency'
]
=
experiment_config
[
'trialConcurrency'
]
request_data
[
'maxExecDuration'
]
=
experiment_config
[
'maxExecDuration'
]
request_data
[
'maxExecDuration'
]
=
experiment_config
[
'maxExecDuration'
]
request_data
[
'maxExperimentDuration'
]
=
str
(
experiment_config
[
'maxExecDuration'
])
+
's'
request_data
[
'maxTrialNum'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'maxTrialNum'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'maxTrialNumber'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'searchSpace'
]
=
experiment_config
.
get
(
'searchSpace'
)
request_data
[
'searchSpace'
]
=
experiment_config
.
get
(
'searchSpace'
)
request_data
[
'trainingServicePlatform'
]
=
experiment_config
.
get
(
'trainingServicePlatform'
)
request_data
[
'trainingServicePlatform'
]
=
experiment_config
.
get
(
'trainingServicePlatform'
)
# hack for hotfix, fix config.trainingService undefined error, need refactor
request_data
[
'trainingService'
]
=
{
'platform'
:
experiment_config
.
get
(
'trainingServicePlatform'
)}
if
experiment_config
.
get
(
'description'
):
if
experiment_config
.
get
(
'description'
):
request_data
[
'description'
]
=
experiment_config
[
'description'
]
request_data
[
'description'
]
=
experiment_config
[
'description'
]
if
experiment_config
.
get
(
'multiPhase'
):
if
experiment_config
.
get
(
'multiPhase'
):
...
@@ -319,7 +364,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -319,7 +364,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
if
package_name
in
[
'SMAC'
,
'BOHB'
,
'PPOTuner'
]:
if
package_name
in
[
'SMAC'
,
'BOHB'
,
'PPOTuner'
]:
print_error
(
f
'The dependencies for
{
package_name
}
can be installed through pip install nni[
{
package_name
}
]'
)
print_error
(
f
'The dependencies for
{
package_name
}
can be installed through pip install nni[
{
package_name
}
]'
)
raise
raise
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
NNI_HOME_DIR
if
config_version
==
1
:
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
NNI_HOME_DIR
else
:
log_dir
=
experiment_config
[
'experimentWorkingDirectory'
]
if
experiment_config
.
get
(
'experimentWorkingDirectory'
)
else
NNI_HOME_DIR
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground
=
False
foreground
=
False
...
@@ -330,6 +378,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -330,6 +378,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
# start rest server
# start rest server
if
config_version
==
1
:
if
config_version
==
1
:
platform
=
experiment_config
[
'trainingServicePlatform'
]
platform
=
experiment_config
[
'trainingServicePlatform'
]
elif
isinstance
(
experiment_config
[
'trainingService'
],
list
):
platform
=
'hybrid'
else
:
else
:
platform
=
experiment_config
[
'trainingService'
][
'platform'
]
platform
=
experiment_config
[
'trainingService'
][
'platform'
]
...
@@ -349,14 +399,14 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -349,14 +399,14 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
nas_mode
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
nas_mode
)
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
search_space
=
generate_search_space
(
code_dir
)
search_space
=
generate_search_space
(
code_dir
)
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
experiment_config
[
'searchSpace'
]
=
search_space
assert
search_space
,
ERROR_INFO
%
'Generated search space is empty'
assert
search_space
,
ERROR_INFO
%
'Generated search space is empty'
elif
config_version
==
1
:
elif
config_version
==
1
:
if
experiment_config
.
get
(
'searchSpacePath'
):
if
experiment_config
.
get
(
'searchSpacePath'
):
search_space
=
get_json_content
(
experiment_config
.
get
(
'searchSpacePath'
))
search_space
=
get_json_content
(
experiment_config
.
get
(
'searchSpacePath'
))
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
experiment_config
[
'searchSpace'
]
=
search_space
else
:
else
:
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
''
)
experiment_config
[
'searchSpace'
]
=
''
# check rest server
# check rest server
running
,
_
=
check_rest_server
(
args
.
port
)
running
,
_
=
check_rest_server
(
args
.
port
)
...
@@ -411,6 +461,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -411,6 +461,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
kill_command
(
rest_process
.
pid
)
kill_command
(
rest_process
.
pid
)
print_normal
(
'Stopping experiment...'
)
print_normal
(
'Stopping experiment...'
)
def
_validate_v1
(
config
,
path
):
try
:
validate_all_content
(
config
,
path
)
except
Exception
as
e
:
print_error
(
f
'Config V1 validation failed:
{
repr
(
e
)
}
'
)
exit
(
1
)
def
_validate_v2
(
config
,
path
):
base_path
=
Path
(
path
).
parent
try
:
conf
=
ExperimentConfig
(
_base_path
=
base_path
,
**
config
)
return
conf
.
json
()
except
Exception
as
e
:
print_error
(
f
'Config V2 validation failed:
{
repr
(
e
)
}
'
)
def
create_experiment
(
args
):
def
create_experiment
(
args
):
'''start a new experiment'''
'''start a new experiment'''
experiment_id
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
experiment_id
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
...
@@ -420,23 +485,23 @@ def create_experiment(args):
...
@@ -420,23 +485,23 @@ def create_experiment(args):
exit
(
1
)
exit
(
1
)
config_yml
=
get_yml_content
(
config_path
)
config_yml
=
get_yml_content
(
config_path
)
try
:
if
'trainingServicePlatform'
in
config_yml
:
config
=
ExperimentConfig
(
_base_path
=
Path
(
config_path
).
parent
,
**
config_
yml
)
_validate_v1
(
config_yml
,
config_
path
)
config_v2
=
config
.
json
()
platform
=
config_yml
[
'trainingServicePlatform'
]
except
Exception
as
error_v2
:
if
platform
in
k8s_training_services
:
print_warning
(
'Validation with V2 schema failed. Trying to convert from V1 format...'
)
schema
=
1
try
:
config_v1
=
config_yml
validate_all_content
(
config_yml
,
config_path
)
else
:
except
Exception
as
error_v1
:
schema
=
2
print_error
(
f
'Convert from v1 format failed:
{
repr
(
error_v1
)
}
'
)
from
nni.experiment.config
import
convert
print_error
(
f
'Config in v2 format validation failed:
{
repr
(
error_v2
)
}
'
)
config_v2
=
convert
.
to_v2
(
config_yml
).
json
(
)
exit
(
1
)
else
:
from
nni.experiment.config
import
convert
config_v2
=
_validate_v2
(
config_yml
,
config_path
)
config_v2
=
convert
.
to_v2
(
config_yml
).
json
()
schema
=
2
try
:
try
:
if
getattr
(
config_v2
[
'trainingService'
],
'platform'
,
None
)
in
k8s_training_services
:
if
schema
==
1
:
launch_experiment
(
args
,
config_
yml
,
'new'
,
experiment_id
,
1
)
launch_experiment
(
args
,
config_
v1
,
'new'
,
experiment_id
,
1
)
else
:
else
:
launch_experiment
(
args
,
config_v2
,
'new'
,
experiment_id
,
2
)
launch_experiment
(
args
,
config_v2
,
'new'
,
experiment_id
,
2
)
except
Exception
as
exception
:
except
Exception
as
exception
:
...
@@ -470,10 +535,12 @@ def manage_stopped_experiment(args, mode):
...
@@ -470,10 +535,12 @@ def manage_stopped_experiment(args, mode):
experiments_config
.
update_experiment
(
args
.
id
,
'port'
,
args
.
port
)
experiments_config
.
update_experiment
(
args
.
id
,
'port'
,
args
.
port
)
assert
'trainingService'
in
experiment_config
or
'trainingServicePlatform'
in
experiment_config
assert
'trainingService'
in
experiment_config
or
'trainingServicePlatform'
in
experiment_config
try
:
try
:
if
'trainingService'
in
experiment_config
:
if
'trainingServicePlatform'
in
experiment_config
:
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
2
)
experiment_config
[
'logDir'
]
=
experiments_dict
[
args
.
id
][
'logDir'
]
else
:
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
1
)
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
1
)
else
:
experiment_config
[
'experimentWorkingDirectory'
]
=
experiments_dict
[
args
.
id
][
'logDir'
]
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
2
)
except
Exception
as
exception
:
except
Exception
as
exception
:
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
if
restServerPid
:
if
restServerPid
:
...
...
nni/tools/nnictl/nnictl_utils.py
View file @
063d6b74
...
@@ -13,7 +13,6 @@ from functools import cmp_to_key
...
@@ -13,7 +13,6 @@ from functools import cmp_to_key
import
traceback
import
traceback
from
datetime
import
datetime
,
timezone
from
datetime
import
datetime
,
timezone
from
subprocess
import
Popen
from
subprocess
import
Popen
from
pyhdfs
import
HdfsClient
from
nni.tools.annotation
import
expand_annotations
from
nni.tools.annotation
import
expand_annotations
import
nni_node
# pylint: disable=import-error
import
nni_node
# pylint: disable=import-error
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
...
@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
...
@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
remove_remote_directory
(
sftp
,
remote_dir
)
remove_remote_directory
(
sftp
,
remote_dir
)
def
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
=
None
):
'''clean up hdfs data'''
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
host
),
user_name
=
user_name
,
webhdfs_path
=
'/webhdfs/api/v1'
,
timeout
=
5
)
if
experiment_id
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
])
print_normal
(
'removing folder {0} in hdfs'
.
format
(
full_path
))
hdfs_client
.
delete
(
full_path
,
recursive
=
True
)
if
output_dir
:
pattern
=
re
.
compile
(
'hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?'
)
match_result
=
pattern
.
match
(
output_dir
)
if
match_result
:
output_host
=
match_result
.
group
(
'host'
)
output_dir
=
match_result
.
group
(
'baseDir'
)
#check if the host is valid
if
output_host
!=
host
:
print_warning
(
'The host in {0} is not consistent with {1}'
.
format
(
output_dir
,
host
))
else
:
if
experiment_id
:
output_dir
=
output_dir
+
'/'
+
experiment_id
print_normal
(
'removing folder {0} in hdfs'
.
format
(
output_dir
))
hdfs_client
.
delete
(
output_dir
,
recursive
=
True
)
def
experiment_clean
(
args
):
def
experiment_clean
(
args
):
'''clean up the experiment data'''
'''clean up the experiment data'''
experiment_id_list
=
[]
experiment_id_list
=
[]
...
@@ -556,11 +531,6 @@ def experiment_clean(args):
...
@@ -556,11 +531,6 @@ def experiment_clean(args):
if
platform
==
'remote'
:
if
platform
==
'remote'
:
machine_list
=
experiment_config
.
get
(
'machineList'
)
machine_list
=
experiment_config
.
get
(
'machineList'
)
remote_clean
(
machine_list
,
experiment_id
)
remote_clean
(
machine_list
,
experiment_id
)
elif
platform
==
'pai'
:
host
=
experiment_config
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
experiment_config
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
experiment_config
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
)
elif
platform
!=
'local'
:
elif
platform
!=
'local'
:
# TODO: support all platforms
# TODO: support all platforms
print_warning
(
'platform {0} clean up not supported yet.'
.
format
(
platform
))
print_warning
(
'platform {0} clean up not supported yet.'
.
format
(
platform
))
...
@@ -632,11 +602,6 @@ def platform_clean(args):
...
@@ -632,11 +602,6 @@ def platform_clean(args):
if
platform
==
'remote'
:
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
machine_list
=
config_content
.
get
(
'machineList'
)
remote_clean
(
machine_list
)
remote_clean
(
machine_list
)
elif
platform
==
'pai'
:
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
)
print_normal
(
'Done.'
)
print_normal
(
'Done.'
)
def
experiment_list
(
args
):
def
experiment_list
(
args
):
...
...
pipelines/integration-test-adl.yml
View file @
063d6b74
...
@@ -59,5 +59,5 @@ jobs:
...
@@ -59,5 +59,5 @@ jobs:
--checkpoint_storage_class $(checkpoint_storage_class) \
--checkpoint_storage_class $(checkpoint_storage_class) \
--checkpoint_storage_size $(checkpoint_storage_size) \
--checkpoint_storage_size $(checkpoint_storage_size) \
--nni_manager_ip $(nni_manager_ip)
--nni_manager_ip $(nni_manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
--exclude multi-phase,multi-thread
displayName
:
Integration test
displayName
:
Integration test
pipelines/integration-test-frameworkcontroller.yml
View file @
063d6b74
...
@@ -48,5 +48,5 @@ jobs:
...
@@ -48,5 +48,5 @@ jobs:
--azs_share nni \
--azs_share nni \
--nni_docker_image nnidev/nni-nightly \
--nni_docker_image nnidev/nni-nightly \
--nni_manager_ip $(manager_ip)
--nni_manager_ip $(manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts frameworkcontroller --exclude multi-phase
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts frameworkcontroller --exclude multi-phase
,multi-thread
displayName
:
Integration test
displayName
:
Integration test
pipelines/integration-test-kubeflow.yml
View file @
063d6b74
...
@@ -58,5 +58,5 @@ jobs:
...
@@ -58,5 +58,5 @@ jobs:
--azs_share nni \
--azs_share nni \
--nni_docker_image nnidev/nni-nightly \
--nni_docker_image nnidev/nni-nightly \
--nni_manager_ip $(manager_ip)
--nni_manager_ip $(manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts kubeflow --exclude multi-phase
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts kubeflow --exclude multi-phase
,multi-thread
displayName
:
Integration test
displayName
:
Integration test
test/config/examples/classic-nas-pytorch.yml
View file @
063d6b74
...
@@ -11,7 +11,7 @@ tuner:
...
@@ -11,7 +11,7 @@ tuner:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
command
:
python3 mnist.py --epochs
1
command
:
python3 mnist.py --epochs
1
codeDir
:
../../../examples/nas/classic_nas
codeDir
:
../../../examples/nas/
legacy/
classic_nas
gpuNum
:
0
gpuNum
:
0
useAnnotation
:
false
useAnnotation
:
false
...
...
test/config/examples/classic-nas-tf2.yml
View file @
063d6b74
...
@@ -11,7 +11,7 @@ tuner:
...
@@ -11,7 +11,7 @@ tuner:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
command
:
python3 train.py --epochs
1
command
:
python3 train.py --epochs
1
codeDir
:
../../../examples/nas/classic_nas-tf
codeDir
:
../../../examples/nas/
legacy/
classic_nas-tf
gpuNum
:
0
gpuNum
:
0
useAnnotation
:
false
useAnnotation
:
false
...
...
test/config/integration_tests.yml
View file @
063d6b74
...
@@ -40,6 +40,7 @@ testCases:
...
@@ -40,6 +40,7 @@ testCases:
-
name
:
mnist-tensorflow
-
name
:
mnist-tensorflow
configFile
:
test/config/examples/mnist-tfv2.yml
configFile
:
test/config/examples/mnist-tfv2.yml
config
:
config
:
maxExecDuration
:
10m
# This example will use longger time in remote mode, set max_duration to 10m to avoid timeout error.
maxTrialNum
:
1
maxTrialNum
:
1
trialConcurrency
:
1
trialConcurrency
:
1
trainingService
:
local remote
# FIXME: timeout on pai, looks like tensorflow failed to link CUDA
trainingService
:
local remote
# FIXME: timeout on pai, looks like tensorflow failed to link CUDA
...
@@ -84,7 +85,7 @@ testCases:
...
@@ -84,7 +85,7 @@ testCases:
-
name
:
classic-nas-gen-ss
-
name
:
classic-nas-gen-ss
configFile
:
test/config/examples/classic-nas-pytorch.yml
configFile
:
test/config/examples/classic-nas-pytorch.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/classic_nas --file=config/examples/nni-nas-search-space.json
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/
legacy/
classic_nas --file=config/examples/nni-nas-search-space.json
stopCommand
:
stopCommand
:
experimentStatusCheck
:
False
experimentStatusCheck
:
False
trainingService
:
local
trainingService
:
local
...
@@ -170,27 +171,6 @@ testCases:
...
@@ -170,27 +171,6 @@ testCases:
-
name
:
multi-thread
-
name
:
multi-thread
configFile
:
test/config/multi_thread/config.yml
configFile
:
test/config/multi_thread/config.yml
-
name
:
multi-phase-batch
configFile
:
test/config/multi_phase/batch.yml
config
:
# for batch tuner, maxTrialNum can not exceed length of search space
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-evolution
configFile
:
test/config/multi_phase/evolution.yml
-
name
:
multi-phase-grid
configFile
:
test/config/multi_phase/grid.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-metis
configFile
:
test/config/multi_phase/metis.yml
-
name
:
multi-phase-tpe
configFile
:
test/config/multi_phase/tpe.yml
#########################################################################
#########################################################################
# nni assessor test
# nni assessor test
...
...
test/config/integration_tests_tf2.yml
View file @
063d6b74
...
@@ -58,7 +58,7 @@ testCases:
...
@@ -58,7 +58,7 @@ testCases:
-
name
:
classic-nas-gen-ss
-
name
:
classic-nas-gen-ss
configFile
:
test/config/examples/classic-nas-tf2.yml
configFile
:
test/config/examples/classic-nas-tf2.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 train.py --epochs 1" --trial_dir=../examples/nas/classic_nas-tf --file=config/examples/nni-nas-search-space-tf2.json
launchCommand
:
nnictl ss_gen --trial_command="python3 train.py --epochs 1" --trial_dir=../examples/nas/
legacy/
classic_nas-tf --file=config/examples/nni-nas-search-space-tf2.json
stopCommand
:
stopCommand
:
experimentStatusCheck
:
False
experimentStatusCheck
:
False
trainingService
:
local
trainingService
:
local
...
@@ -135,28 +135,6 @@ testCases:
...
@@ -135,28 +135,6 @@ testCases:
-
name
:
multi-thread
-
name
:
multi-thread
configFile
:
test/config/multi_thread/config.yml
configFile
:
test/config/multi_thread/config.yml
-
name
:
multi-phase-batch
configFile
:
test/config/multi_phase/batch.yml
config
:
# for batch tuner, maxTrialNum can not exceed length of search space
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-evolution
configFile
:
test/config/multi_phase/evolution.yml
-
name
:
multi-phase-grid
configFile
:
test/config/multi_phase/grid.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-metis
configFile
:
test/config/multi_phase/metis.yml
-
name
:
multi-phase-tpe
configFile
:
test/config/multi_phase/tpe.yml
#########################################################################
#########################################################################
# nni assessor test
# nni assessor test
#########################################################################
#########################################################################
...
...
test/config/metrics_test/trial.py
View file @
063d6b74
...
@@ -19,6 +19,7 @@ if __name__ == '__main__':
...
@@ -19,6 +19,7 @@ if __name__ == '__main__':
nni
.
get_next_parameter
()
nni
.
get_next_parameter
()
with
open
(
result_file
,
'r'
)
as
f
:
with
open
(
result_file
,
'r'
)
as
f
:
m
=
json
.
load
(
f
)
m
=
json
.
load
(
f
)
time
.
sleep
(
5
)
for
v
in
m
[
'intermediate_result'
]:
for
v
in
m
[
'intermediate_result'
]:
time
.
sleep
(
1
)
time
.
sleep
(
1
)
print
(
'report_intermediate_result:'
,
v
)
print
(
'report_intermediate_result:'
,
v
)
...
...
test/config/tuners/regularized_evolution_tuner.yml
View file @
063d6b74
...
@@ -9,7 +9,7 @@ tuner:
...
@@ -9,7 +9,7 @@ tuner:
classArgs
:
classArgs
:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
codeDir
:
../../../examples/nas/classic_nas
codeDir
:
../../../examples/nas/
legacy/
classic_nas
command
:
python3 mnist.py --epochs
1
command
:
python3 mnist.py --epochs
1
gpuNum
:
0
gpuNum
:
0
...
...
test/scripts/nas.sh
View file @
063d6b74
...
@@ -7,7 +7,7 @@ echo "===========================Testing: NAS==========================="
...
@@ -7,7 +7,7 @@ echo "===========================Testing: NAS==========================="
EXAMPLE_DIR
=
${
CWD
}
/../examples/nas
EXAMPLE_DIR
=
${
CWD
}
/../examples/nas
echo
"testing nnictl ss_gen (classic nas)..."
echo
"testing nnictl ss_gen (classic nas)..."
cd
$EXAMPLE_DIR
/classic_nas
cd
$EXAMPLE_DIR
/
legacy/
classic_nas
SEARCH_SPACE_JSON
=
nni_auto_gen_search_space.json
SEARCH_SPACE_JSON
=
nni_auto_gen_search_space.json
if
[
-f
$SEARCH_SPACE_JSON
]
;
then
if
[
-f
$SEARCH_SPACE_JSON
]
;
then
rm
$SEARCH_SPACE_JSON
rm
$SEARCH_SPACE_JSON
...
@@ -19,12 +19,12 @@ if [ ! -f $SEARCH_SPACE_JSON ]; then
...
@@ -19,12 +19,12 @@ if [ ! -f $SEARCH_SPACE_JSON ]; then
fi
fi
echo
"testing darts..."
echo
"testing darts..."
cd
$EXAMPLE_DIR
/darts
cd
$EXAMPLE_DIR
/
oneshot/
darts
python3 search.py
--epochs
1
--channels
2
--layers
4
python3 search.py
--epochs
1
--channels
2
--layers
4
python3 retrain.py
--arc-checkpoint
./checkpoint.json
--layers
4
--epochs
1
python3 retrain.py
--arc-checkpoint
./checkpoint.json
--layers
4
--epochs
1
echo
"testing enas..."
echo
"testing enas..."
cd
$EXAMPLE_DIR
/enas
cd
$EXAMPLE_DIR
/
oneshot/
enas
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
...
@@ -34,5 +34,5 @@ python3 search.py --search-for micro --epochs 1
...
@@ -34,5 +34,5 @@ python3 search.py --search-for micro --epochs 1
#python3 train.py
#python3 train.py
echo
"testing pdarts..."
echo
"testing pdarts..."
cd
$EXAMPLE_DIR
/pdarts
cd
$EXAMPLE_DIR
/
legacy/
pdarts
python3 search.py
--epochs
1
--channels
4
--nodes
2
--log-frequency
10
--add_layers
0
--add_layers
1
--dropped_ops
3
--dropped_ops
3
python3 search.py
--epochs
1
--channels
4
--nodes
2
--log-frequency
10
--add_layers
0
--add_layers
1
--dropped_ops
3
--dropped_ops
3
test/ut/retiarii/test_strategy.py
View file @
063d6b74
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
return
self
.
_resource_left
return
self
.
_resource_left
def
budget_exhausted
(
self
)
->
bool
:
pass
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
pass
pass
...
...
ts/nni_manager/common/experimentConfig.ts
View file @
063d6b74
...
@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
...
@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
workspaceName
:
string
;
workspaceName
:
string
;
computeTarget
:
string
;
computeTarget
:
string
;
dockerImage
:
string
;
dockerImage
:
string
;
maxTrialNumberPerGpu
:
number
;
}
}
/* Kubeflow */
/* Kubeflow */
...
...
ts/nni_manager/common/manager.ts
View file @
063d6b74
...
@@ -8,7 +8,7 @@ import { TrialJobStatus, LogType } from './trainingService';
...
@@ -8,7 +8,7 @@ import { TrialJobStatus, LogType } from './trainingService';
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
type
ExperimentStatus
=
'
INITIALIZED
'
|
'
RUNNING
'
|
'
ERROR
'
|
'
STOPPING
'
|
'
STOPPED
'
|
'
DONE
'
|
'
NO_MORE_TRIAL
'
|
'
TUNER_NO_MORE_TRIAL
'
;
type
ExperimentStatus
=
'
INITIALIZED
'
|
'
RUNNING
'
|
'
ERROR
'
|
'
STOPPING
'
|
'
STOPPED
'
|
'
DONE
'
|
'
NO_MORE_TRIAL
'
|
'
TUNER_NO_MORE_TRIAL
'
|
'
VIEWED
'
;
namespace
ExperimentStartUpMode
{
namespace
ExperimentStartUpMode
{
export
const
NEW
=
'
new
'
;
export
const
NEW
=
'
new
'
;
export
const
RESUME
=
'
resume
'
;
export
const
RESUME
=
'
resume
'
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment