Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
063d6b74
Unverified
Commit
063d6b74
authored
Apr 26, 2021
by
SparkSnail
Committed by
GitHub
Apr 26, 2021
Browse files
Merge pull request #3580 from microsoft/v2.2
[do not Squash!] Merge V2.2 back to master
parents
08986c6b
e1295888
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
146 additions
and
130 deletions
+146
-130
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+7
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+15
-5
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-4
nni/runtime/log.py
nni/runtime/log.py
+7
-3
nni/tools/nnictl/config_utils.py
nni/tools/nnictl/config_utils.py
+5
-5
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+89
-22
nni/tools/nnictl/nnictl_utils.py
nni/tools/nnictl/nnictl_utils.py
+0
-35
pipelines/integration-test-adl.yml
pipelines/integration-test-adl.yml
+1
-1
pipelines/integration-test-frameworkcontroller.yml
pipelines/integration-test-frameworkcontroller.yml
+1
-1
pipelines/integration-test-kubeflow.yml
pipelines/integration-test-kubeflow.yml
+1
-1
test/config/examples/classic-nas-pytorch.yml
test/config/examples/classic-nas-pytorch.yml
+1
-1
test/config/examples/classic-nas-tf2.yml
test/config/examples/classic-nas-tf2.yml
+1
-1
test/config/integration_tests.yml
test/config/integration_tests.yml
+2
-22
test/config/integration_tests_tf2.yml
test/config/integration_tests_tf2.yml
+1
-23
test/config/metrics_test/trial.py
test/config/metrics_test/trial.py
+1
-0
test/config/tuners/regularized_evolution_tuner.yml
test/config/tuners/regularized_evolution_tuner.yml
+1
-1
test/scripts/nas.sh
test/scripts/nas.sh
+4
-4
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
ts/nni_manager/common/experimentConfig.ts
ts/nni_manager/common/experimentConfig.ts
+1
-0
ts/nni_manager/common/manager.ts
ts/nni_manager/common/manager.ts
+1
-1
No files found.
nni/retiarii/execution/interface.py
View file @
063d6b74
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
budget_exhausted
(
self
)
->
bool
:
"""
Check whether user configured max trial number or max execution duration has been reached
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
"""
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
063d6b74
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger
.
info
(
'Start strategy...'
)
_logger
.
info
(
'Start strategy...'
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
_logger
.
info
(
'Strategy exit'
)
self
.
_dispatcher
.
mark_experiment_as_ending
()
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
Thread
(
target
=
self
.
_check_exp_status
).
start
()
exp_status_checker
=
Thread
(
target
=
self
.
_check_exp_status
)
exp_status_checker
.
start
()
self
.
_start_strategy
()
self
.
_start_strategy
()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
exp_status_checker
.
join
()
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try
:
try
:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
status
=
self
.
get_status
()
else
:
return
False
if
status
==
'DONE'
or
status
==
'STOPPED'
:
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
return
True
if
status
==
'ERROR'
:
if
status
==
'ERROR'
:
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
if
self
.
_proc
is
not
None
:
try
:
try
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
exception
(
e
)
_logger
.
exception
(
e
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
063d6b74
...
@@ -6,7 +6,7 @@ import time
...
@@ -6,7 +6,7 @@ import time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
,
budget_exhausted
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
_logger
.
info
(
'TPE strategy has been started.'
)
_logger
.
info
(
'TPE strategy has been started.'
)
while
True
:
while
not
budget_exhausted
()
:
avail_resource
=
query_available_resources
()
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else
:
else
:
time
.
sleep
(
2
)
time
.
sleep
(
2
)
_logger
.
warnin
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
_logger
.
debu
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warnin
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
_logger
.
debu
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
del
self
.
running_models
[
_id
]
nni/runtime/log.py
View file @
063d6b74
...
@@ -46,6 +46,7 @@ def init_logger() -> None:
...
@@ -46,6 +46,7 @@ def init_logger() -> None:
logging
.
getLogger
(
'filelock'
).
setLevel
(
logging
.
WARNING
)
logging
.
getLogger
(
'filelock'
).
setLevel
(
logging
.
WARNING
)
_exp_log_initialized
=
False
def
init_logger_experiment
()
->
None
:
def
init_logger_experiment
()
->
None
:
"""
"""
...
@@ -53,9 +54,12 @@ def init_logger_experiment() -> None:
...
@@ -53,9 +54,12 @@ def init_logger_experiment() -> None:
This function will get invoked after `init_logger()`.
This function will get invoked after `init_logger()`.
"""
"""
colorful_formatter
=
Formatter
(
log_format
,
time_format
)
global
_exp_log_initialized
colorful_formatter
.
format
=
_colorful_format
if
not
_exp_log_initialized
:
handlers
[
'_default_'
].
setFormatter
(
colorful_formatter
)
_exp_log_initialized
=
True
colorful_formatter
=
Formatter
(
log_format
,
time_format
)
colorful_formatter
.
format
=
_colorful_format
handlers
[
'_default_'
].
setFormatter
(
colorful_formatter
)
def
start_experiment_log
(
experiment_id
:
str
,
log_directory
:
Path
,
debug
:
bool
)
->
None
:
def
start_experiment_log
(
experiment_id
:
str
,
log_directory
:
Path
,
debug
:
bool
)
->
None
:
log_path
=
_prepare_log_dir
(
log_directory
)
/
'dispatcher.log'
log_path
=
_prepare_log_dir
(
log_directory
)
/
'dispatcher.log'
...
...
nni/tools/nnictl/config_utils.py
View file @
063d6b74
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
os
import
os
import
json
import
json
_tricks
import
shutil
import
shutil
import
sqlite3
import
sqlite3
import
time
import
time
...
@@ -92,7 +92,7 @@ class Config:
...
@@ -92,7 +92,7 @@ class Config:
'''refresh to get latest config'''
'''refresh to get latest config'''
sql
=
'select params from ExperimentProfile where id=? order by revision DESC'
sql
=
'select params from ExperimentProfile where id=? order by revision DESC'
args
=
(
self
.
experiment_id
,)
args
=
(
self
.
experiment_id
,)
self
.
config
=
config_v0_to_v1
(
json
.
loads
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
self
.
config
=
config_v0_to_v1
(
json
_tricks
.
loads
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
def
get_config
(
self
):
def
get_config
(
self
):
'''get a value according to key'''
'''get a value according to key'''
...
@@ -123,7 +123,7 @@ class Experiments:
...
@@ -123,7 +123,7 @@ class Experiments:
self
.
experiments
[
expId
][
'tag'
]
=
tag
self
.
experiments
[
expId
][
'tag'
]
=
tag
self
.
experiments
[
expId
][
'pid'
]
=
pid
self
.
experiments
[
expId
][
'pid'
]
=
pid
self
.
experiments
[
expId
][
'webuiUrl'
]
=
webuiUrl
self
.
experiments
[
expId
][
'webuiUrl'
]
=
webuiUrl
self
.
experiments
[
expId
][
'logDir'
]
=
logDir
self
.
experiments
[
expId
][
'logDir'
]
=
str
(
logDir
)
self
.
write_file
()
self
.
write_file
()
def
update_experiment
(
self
,
expId
,
key
,
value
):
def
update_experiment
(
self
,
expId
,
key
,
value
):
...
@@ -155,7 +155,7 @@ class Experiments:
...
@@ -155,7 +155,7 @@ class Experiments:
'''save config to local file'''
'''save config to local file'''
try
:
try
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
json
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
json
_tricks
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
except
IOError
as
error
:
except
IOError
as
error
:
print
(
'Error:'
,
error
)
print
(
'Error:'
,
error
)
return
''
return
''
...
@@ -165,7 +165,7 @@ class Experiments:
...
@@ -165,7 +165,7 @@ class Experiments:
if
os
.
path
.
exists
(
self
.
experiment_file
):
if
os
.
path
.
exists
(
self
.
experiment_file
):
try
:
try
:
with
open
(
self
.
experiment_file
,
'r'
)
as
file
:
with
open
(
self
.
experiment_file
,
'r'
)
as
file
:
return
json
.
load
(
file
)
return
json
_tricks
.
load
(
file
)
except
ValueError
:
except
ValueError
:
return
{}
return
{}
return
{}
return
{}
nni/tools/nnictl/launcher.py
View file @
063d6b74
...
@@ -119,12 +119,51 @@ def set_trial_config(experiment_config, port, config_file_name):
...
@@ -119,12 +119,51 @@ def set_trial_config(experiment_config, port, config_file_name):
def
set_adl_config
(
experiment_config
,
port
,
config_file_name
):
def
set_adl_config
(
experiment_config
,
port
,
config_file_name
):
'''set adl configuration'''
'''set adl configuration'''
adl_config_data
=
dict
()
# hack for supporting v2 config, need refactor
adl_config_data
[
'adl_config'
]
=
{}
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
adl_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
#set trial_config
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
None
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
None
def
validate_response
(
response
,
config_file_name
):
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
print_error
(
'Error:'
+
err_message
)
exit
(
1
)
# hack to fix v1 version_check and log_collection bug, need refactor
def
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
):
version_check
=
True
#debug mode should disable version check
if
experiment_config
.
get
(
'debug'
)
is
not
None
:
version_check
=
not
experiment_config
.
get
(
'debug'
)
#validate version check
if
experiment_config
.
get
(
'versionCheck'
)
is
not
None
:
version_check
=
experiment_config
.
get
(
'versionCheck'
)
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
({
'version_check'
:
version_check
}),
REST_TIME_OUT
)
validate_response
(
response
,
config_file_name
)
if
experiment_config
.
get
(
'logCollection'
):
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
({
'log_collection'
:
experiment_config
.
get
(
'logCollection'
)}),
REST_TIME_OUT
)
validate_response
(
response
,
config_file_name
)
def
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
):
def
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
):
'''set nniManagerIp'''
'''set nniManagerIp'''
if
experiment_config
.
get
(
'nniManagerIp'
)
is
None
:
if
experiment_config
.
get
(
'nniManagerIp'
)
is
None
:
...
@@ -155,6 +194,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
...
@@ -155,6 +194,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
...
@@ -174,6 +214,7 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
...
@@ -174,6 +214,7 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
return
False
,
err_message
set_V1_common_config
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
if
not
result
:
return
result
,
message
return
result
,
message
...
@@ -200,9 +241,13 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name):
...
@@ -200,9 +241,13 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name):
request_data
[
'experimentName'
]
=
experiment_config
[
'experimentName'
]
request_data
[
'experimentName'
]
=
experiment_config
[
'experimentName'
]
request_data
[
'trialConcurrency'
]
=
experiment_config
[
'trialConcurrency'
]
request_data
[
'trialConcurrency'
]
=
experiment_config
[
'trialConcurrency'
]
request_data
[
'maxExecDuration'
]
=
experiment_config
[
'maxExecDuration'
]
request_data
[
'maxExecDuration'
]
=
experiment_config
[
'maxExecDuration'
]
request_data
[
'maxExperimentDuration'
]
=
str
(
experiment_config
[
'maxExecDuration'
])
+
's'
request_data
[
'maxTrialNum'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'maxTrialNum'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'maxTrialNumber'
]
=
experiment_config
[
'maxTrialNum'
]
request_data
[
'searchSpace'
]
=
experiment_config
.
get
(
'searchSpace'
)
request_data
[
'searchSpace'
]
=
experiment_config
.
get
(
'searchSpace'
)
request_data
[
'trainingServicePlatform'
]
=
experiment_config
.
get
(
'trainingServicePlatform'
)
request_data
[
'trainingServicePlatform'
]
=
experiment_config
.
get
(
'trainingServicePlatform'
)
# hack for hotfix, fix config.trainingService undefined error, need refactor
request_data
[
'trainingService'
]
=
{
'platform'
:
experiment_config
.
get
(
'trainingServicePlatform'
)}
if
experiment_config
.
get
(
'description'
):
if
experiment_config
.
get
(
'description'
):
request_data
[
'description'
]
=
experiment_config
[
'description'
]
request_data
[
'description'
]
=
experiment_config
[
'description'
]
if
experiment_config
.
get
(
'multiPhase'
):
if
experiment_config
.
get
(
'multiPhase'
):
...
@@ -319,7 +364,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -319,7 +364,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
if
package_name
in
[
'SMAC'
,
'BOHB'
,
'PPOTuner'
]:
if
package_name
in
[
'SMAC'
,
'BOHB'
,
'PPOTuner'
]:
print_error
(
f
'The dependencies for
{
package_name
}
can be installed through pip install nni[
{
package_name
}
]'
)
print_error
(
f
'The dependencies for
{
package_name
}
can be installed through pip install nni[
{
package_name
}
]'
)
raise
raise
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
NNI_HOME_DIR
if
config_version
==
1
:
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
NNI_HOME_DIR
else
:
log_dir
=
experiment_config
[
'experimentWorkingDirectory'
]
if
experiment_config
.
get
(
'experimentWorkingDirectory'
)
else
NNI_HOME_DIR
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground
=
False
foreground
=
False
...
@@ -330,6 +378,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -330,6 +378,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
# start rest server
# start rest server
if
config_version
==
1
:
if
config_version
==
1
:
platform
=
experiment_config
[
'trainingServicePlatform'
]
platform
=
experiment_config
[
'trainingServicePlatform'
]
elif
isinstance
(
experiment_config
[
'trainingService'
],
list
):
platform
=
'hybrid'
else
:
else
:
platform
=
experiment_config
[
'trainingService'
][
'platform'
]
platform
=
experiment_config
[
'trainingService'
][
'platform'
]
...
@@ -349,14 +399,14 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -349,14 +399,14 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
nas_mode
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
nas_mode
)
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
search_space
=
generate_search_space
(
code_dir
)
search_space
=
generate_search_space
(
code_dir
)
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
experiment_config
[
'searchSpace'
]
=
search_space
assert
search_space
,
ERROR_INFO
%
'Generated search space is empty'
assert
search_space
,
ERROR_INFO
%
'Generated search space is empty'
elif
config_version
==
1
:
elif
config_version
==
1
:
if
experiment_config
.
get
(
'searchSpacePath'
):
if
experiment_config
.
get
(
'searchSpacePath'
):
search_space
=
get_json_content
(
experiment_config
.
get
(
'searchSpacePath'
))
search_space
=
get_json_content
(
experiment_config
.
get
(
'searchSpacePath'
))
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
experiment_config
[
'searchSpace'
]
=
search_space
else
:
else
:
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
''
)
experiment_config
[
'searchSpace'
]
=
''
# check rest server
# check rest server
running
,
_
=
check_rest_server
(
args
.
port
)
running
,
_
=
check_rest_server
(
args
.
port
)
...
@@ -411,6 +461,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
...
@@ -411,6 +461,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
kill_command
(
rest_process
.
pid
)
kill_command
(
rest_process
.
pid
)
print_normal
(
'Stopping experiment...'
)
print_normal
(
'Stopping experiment...'
)
def
_validate_v1
(
config
,
path
):
try
:
validate_all_content
(
config
,
path
)
except
Exception
as
e
:
print_error
(
f
'Config V1 validation failed:
{
repr
(
e
)
}
'
)
exit
(
1
)
def
_validate_v2
(
config
,
path
):
base_path
=
Path
(
path
).
parent
try
:
conf
=
ExperimentConfig
(
_base_path
=
base_path
,
**
config
)
return
conf
.
json
()
except
Exception
as
e
:
print_error
(
f
'Config V2 validation failed:
{
repr
(
e
)
}
'
)
def
create_experiment
(
args
):
def
create_experiment
(
args
):
'''start a new experiment'''
'''start a new experiment'''
experiment_id
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
experiment_id
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
...
@@ -420,23 +485,23 @@ def create_experiment(args):
...
@@ -420,23 +485,23 @@ def create_experiment(args):
exit
(
1
)
exit
(
1
)
config_yml
=
get_yml_content
(
config_path
)
config_yml
=
get_yml_content
(
config_path
)
try
:
if
'trainingServicePlatform'
in
config_yml
:
config
=
ExperimentConfig
(
_base_path
=
Path
(
config_path
).
parent
,
**
config_
yml
)
_validate_v1
(
config_yml
,
config_
path
)
config_v2
=
config
.
json
()
platform
=
config_yml
[
'trainingServicePlatform'
]
except
Exception
as
error_v2
:
if
platform
in
k8s_training_services
:
print_warning
(
'Validation with V2 schema failed. Trying to convert from V1 format...'
)
schema
=
1
try
:
config_v1
=
config_yml
validate_all_content
(
config_yml
,
config_path
)
else
:
except
Exception
as
error_v1
:
schema
=
2
print_error
(
f
'Convert from v1 format failed:
{
repr
(
error_v1
)
}
'
)
from
nni.experiment.config
import
convert
print_error
(
f
'Config in v2 format validation failed:
{
repr
(
error_v2
)
}
'
)
config_v2
=
convert
.
to_v2
(
config_yml
).
json
(
)
exit
(
1
)
else
:
from
nni.experiment.config
import
convert
config_v2
=
_validate_v2
(
config_yml
,
config_path
)
config_v2
=
convert
.
to_v2
(
config_yml
).
json
()
schema
=
2
try
:
try
:
if
getattr
(
config_v2
[
'trainingService'
],
'platform'
,
None
)
in
k8s_training_services
:
if
schema
==
1
:
launch_experiment
(
args
,
config_
yml
,
'new'
,
experiment_id
,
1
)
launch_experiment
(
args
,
config_
v1
,
'new'
,
experiment_id
,
1
)
else
:
else
:
launch_experiment
(
args
,
config_v2
,
'new'
,
experiment_id
,
2
)
launch_experiment
(
args
,
config_v2
,
'new'
,
experiment_id
,
2
)
except
Exception
as
exception
:
except
Exception
as
exception
:
...
@@ -470,10 +535,12 @@ def manage_stopped_experiment(args, mode):
...
@@ -470,10 +535,12 @@ def manage_stopped_experiment(args, mode):
experiments_config
.
update_experiment
(
args
.
id
,
'port'
,
args
.
port
)
experiments_config
.
update_experiment
(
args
.
id
,
'port'
,
args
.
port
)
assert
'trainingService'
in
experiment_config
or
'trainingServicePlatform'
in
experiment_config
assert
'trainingService'
in
experiment_config
or
'trainingServicePlatform'
in
experiment_config
try
:
try
:
if
'trainingService'
in
experiment_config
:
if
'trainingServicePlatform'
in
experiment_config
:
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
2
)
experiment_config
[
'logDir'
]
=
experiments_dict
[
args
.
id
][
'logDir'
]
else
:
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
1
)
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
1
)
else
:
experiment_config
[
'experimentWorkingDirectory'
]
=
experiments_dict
[
args
.
id
][
'logDir'
]
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
,
2
)
except
Exception
as
exception
:
except
Exception
as
exception
:
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
if
restServerPid
:
if
restServerPid
:
...
...
nni/tools/nnictl/nnictl_utils.py
View file @
063d6b74
...
@@ -13,7 +13,6 @@ from functools import cmp_to_key
...
@@ -13,7 +13,6 @@ from functools import cmp_to_key
import
traceback
import
traceback
from
datetime
import
datetime
,
timezone
from
datetime
import
datetime
,
timezone
from
subprocess
import
Popen
from
subprocess
import
Popen
from
pyhdfs
import
HdfsClient
from
nni.tools.annotation
import
expand_annotations
from
nni.tools.annotation
import
expand_annotations
import
nni_node
# pylint: disable=import-error
import
nni_node
# pylint: disable=import-error
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
...
@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
...
@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
remove_remote_directory
(
sftp
,
remote_dir
)
remove_remote_directory
(
sftp
,
remote_dir
)
def
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
=
None
):
'''clean up hdfs data'''
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
host
),
user_name
=
user_name
,
webhdfs_path
=
'/webhdfs/api/v1'
,
timeout
=
5
)
if
experiment_id
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
])
print_normal
(
'removing folder {0} in hdfs'
.
format
(
full_path
))
hdfs_client
.
delete
(
full_path
,
recursive
=
True
)
if
output_dir
:
pattern
=
re
.
compile
(
'hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?'
)
match_result
=
pattern
.
match
(
output_dir
)
if
match_result
:
output_host
=
match_result
.
group
(
'host'
)
output_dir
=
match_result
.
group
(
'baseDir'
)
#check if the host is valid
if
output_host
!=
host
:
print_warning
(
'The host in {0} is not consistent with {1}'
.
format
(
output_dir
,
host
))
else
:
if
experiment_id
:
output_dir
=
output_dir
+
'/'
+
experiment_id
print_normal
(
'removing folder {0} in hdfs'
.
format
(
output_dir
))
hdfs_client
.
delete
(
output_dir
,
recursive
=
True
)
def
experiment_clean
(
args
):
def
experiment_clean
(
args
):
'''clean up the experiment data'''
'''clean up the experiment data'''
experiment_id_list
=
[]
experiment_id_list
=
[]
...
@@ -556,11 +531,6 @@ def experiment_clean(args):
...
@@ -556,11 +531,6 @@ def experiment_clean(args):
if
platform
==
'remote'
:
if
platform
==
'remote'
:
machine_list
=
experiment_config
.
get
(
'machineList'
)
machine_list
=
experiment_config
.
get
(
'machineList'
)
remote_clean
(
machine_list
,
experiment_id
)
remote_clean
(
machine_list
,
experiment_id
)
elif
platform
==
'pai'
:
host
=
experiment_config
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
experiment_config
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
experiment_config
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
)
elif
platform
!=
'local'
:
elif
platform
!=
'local'
:
# TODO: support all platforms
# TODO: support all platforms
print_warning
(
'platform {0} clean up not supported yet.'
.
format
(
platform
))
print_warning
(
'platform {0} clean up not supported yet.'
.
format
(
platform
))
...
@@ -632,11 +602,6 @@ def platform_clean(args):
...
@@ -632,11 +602,6 @@ def platform_clean(args):
if
platform
==
'remote'
:
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
machine_list
=
config_content
.
get
(
'machineList'
)
remote_clean
(
machine_list
)
remote_clean
(
machine_list
)
elif
platform
==
'pai'
:
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
)
print_normal
(
'Done.'
)
print_normal
(
'Done.'
)
def
experiment_list
(
args
):
def
experiment_list
(
args
):
...
...
pipelines/integration-test-adl.yml
View file @
063d6b74
...
@@ -59,5 +59,5 @@ jobs:
...
@@ -59,5 +59,5 @@ jobs:
--checkpoint_storage_class $(checkpoint_storage_class) \
--checkpoint_storage_class $(checkpoint_storage_class) \
--checkpoint_storage_size $(checkpoint_storage_size) \
--checkpoint_storage_size $(checkpoint_storage_size) \
--nni_manager_ip $(nni_manager_ip)
--nni_manager_ip $(nni_manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
--exclude multi-phase,multi-thread
displayName
:
Integration test
displayName
:
Integration test
pipelines/integration-test-frameworkcontroller.yml
View file @
063d6b74
...
@@ -48,5 +48,5 @@ jobs:
...
@@ -48,5 +48,5 @@ jobs:
--azs_share nni \
--azs_share nni \
--nni_docker_image nnidev/nni-nightly \
--nni_docker_image nnidev/nni-nightly \
--nni_manager_ip $(manager_ip)
--nni_manager_ip $(manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts frameworkcontroller --exclude multi-phase
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts frameworkcontroller --exclude multi-phase
,multi-thread
displayName
:
Integration test
displayName
:
Integration test
pipelines/integration-test-kubeflow.yml
View file @
063d6b74
...
@@ -58,5 +58,5 @@ jobs:
...
@@ -58,5 +58,5 @@ jobs:
--azs_share nni \
--azs_share nni \
--nni_docker_image nnidev/nni-nightly \
--nni_docker_image nnidev/nni-nightly \
--nni_manager_ip $(manager_ip)
--nni_manager_ip $(manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts kubeflow --exclude multi-phase
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts kubeflow --exclude multi-phase
,multi-thread
displayName
:
Integration test
displayName
:
Integration test
test/config/examples/classic-nas-pytorch.yml
View file @
063d6b74
...
@@ -11,7 +11,7 @@ tuner:
...
@@ -11,7 +11,7 @@ tuner:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
command
:
python3 mnist.py --epochs
1
command
:
python3 mnist.py --epochs
1
codeDir
:
../../../examples/nas/classic_nas
codeDir
:
../../../examples/nas/
legacy/
classic_nas
gpuNum
:
0
gpuNum
:
0
useAnnotation
:
false
useAnnotation
:
false
...
...
test/config/examples/classic-nas-tf2.yml
View file @
063d6b74
...
@@ -11,7 +11,7 @@ tuner:
...
@@ -11,7 +11,7 @@ tuner:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
command
:
python3 train.py --epochs
1
command
:
python3 train.py --epochs
1
codeDir
:
../../../examples/nas/classic_nas-tf
codeDir
:
../../../examples/nas/
legacy/
classic_nas-tf
gpuNum
:
0
gpuNum
:
0
useAnnotation
:
false
useAnnotation
:
false
...
...
test/config/integration_tests.yml
View file @
063d6b74
...
@@ -40,6 +40,7 @@ testCases:
...
@@ -40,6 +40,7 @@ testCases:
-
name
:
mnist-tensorflow
-
name
:
mnist-tensorflow
configFile
:
test/config/examples/mnist-tfv2.yml
configFile
:
test/config/examples/mnist-tfv2.yml
config
:
config
:
maxExecDuration
:
10m
# This example will use longger time in remote mode, set max_duration to 10m to avoid timeout error.
maxTrialNum
:
1
maxTrialNum
:
1
trialConcurrency
:
1
trialConcurrency
:
1
trainingService
:
local remote
# FIXME: timeout on pai, looks like tensorflow failed to link CUDA
trainingService
:
local remote
# FIXME: timeout on pai, looks like tensorflow failed to link CUDA
...
@@ -84,7 +85,7 @@ testCases:
...
@@ -84,7 +85,7 @@ testCases:
-
name
:
classic-nas-gen-ss
-
name
:
classic-nas-gen-ss
configFile
:
test/config/examples/classic-nas-pytorch.yml
configFile
:
test/config/examples/classic-nas-pytorch.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/classic_nas --file=config/examples/nni-nas-search-space.json
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/
legacy/
classic_nas --file=config/examples/nni-nas-search-space.json
stopCommand
:
stopCommand
:
experimentStatusCheck
:
False
experimentStatusCheck
:
False
trainingService
:
local
trainingService
:
local
...
@@ -170,27 +171,6 @@ testCases:
...
@@ -170,27 +171,6 @@ testCases:
-
name
:
multi-thread
-
name
:
multi-thread
configFile
:
test/config/multi_thread/config.yml
configFile
:
test/config/multi_thread/config.yml
-
name
:
multi-phase-batch
configFile
:
test/config/multi_phase/batch.yml
config
:
# for batch tuner, maxTrialNum can not exceed length of search space
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-evolution
configFile
:
test/config/multi_phase/evolution.yml
-
name
:
multi-phase-grid
configFile
:
test/config/multi_phase/grid.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-metis
configFile
:
test/config/multi_phase/metis.yml
-
name
:
multi-phase-tpe
configFile
:
test/config/multi_phase/tpe.yml
#########################################################################
#########################################################################
# nni assessor test
# nni assessor test
...
...
test/config/integration_tests_tf2.yml
View file @
063d6b74
...
@@ -58,7 +58,7 @@ testCases:
...
@@ -58,7 +58,7 @@ testCases:
-
name
:
classic-nas-gen-ss
-
name
:
classic-nas-gen-ss
configFile
:
test/config/examples/classic-nas-tf2.yml
configFile
:
test/config/examples/classic-nas-tf2.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 train.py --epochs 1" --trial_dir=../examples/nas/classic_nas-tf --file=config/examples/nni-nas-search-space-tf2.json
launchCommand
:
nnictl ss_gen --trial_command="python3 train.py --epochs 1" --trial_dir=../examples/nas/
legacy/
classic_nas-tf --file=config/examples/nni-nas-search-space-tf2.json
stopCommand
:
stopCommand
:
experimentStatusCheck
:
False
experimentStatusCheck
:
False
trainingService
:
local
trainingService
:
local
...
@@ -135,28 +135,6 @@ testCases:
...
@@ -135,28 +135,6 @@ testCases:
-
name
:
multi-thread
-
name
:
multi-thread
configFile
:
test/config/multi_thread/config.yml
configFile
:
test/config/multi_thread/config.yml
-
name
:
multi-phase-batch
configFile
:
test/config/multi_phase/batch.yml
config
:
# for batch tuner, maxTrialNum can not exceed length of search space
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-evolution
configFile
:
test/config/multi_phase/evolution.yml
-
name
:
multi-phase-grid
configFile
:
test/config/multi_phase/grid.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
-
name
:
multi-phase-metis
configFile
:
test/config/multi_phase/metis.yml
-
name
:
multi-phase-tpe
configFile
:
test/config/multi_phase/tpe.yml
#########################################################################
#########################################################################
# nni assessor test
# nni assessor test
#########################################################################
#########################################################################
...
...
test/config/metrics_test/trial.py
View file @
063d6b74
...
@@ -19,6 +19,7 @@ if __name__ == '__main__':
...
@@ -19,6 +19,7 @@ if __name__ == '__main__':
nni
.
get_next_parameter
()
nni
.
get_next_parameter
()
with
open
(
result_file
,
'r'
)
as
f
:
with
open
(
result_file
,
'r'
)
as
f
:
m
=
json
.
load
(
f
)
m
=
json
.
load
(
f
)
time
.
sleep
(
5
)
for
v
in
m
[
'intermediate_result'
]:
for
v
in
m
[
'intermediate_result'
]:
time
.
sleep
(
1
)
time
.
sleep
(
1
)
print
(
'report_intermediate_result:'
,
v
)
print
(
'report_intermediate_result:'
,
v
)
...
...
test/config/tuners/regularized_evolution_tuner.yml
View file @
063d6b74
...
@@ -9,7 +9,7 @@ tuner:
...
@@ -9,7 +9,7 @@ tuner:
classArgs
:
classArgs
:
optimize_mode
:
maximize
optimize_mode
:
maximize
trial
:
trial
:
codeDir
:
../../../examples/nas/classic_nas
codeDir
:
../../../examples/nas/
legacy/
classic_nas
command
:
python3 mnist.py --epochs
1
command
:
python3 mnist.py --epochs
1
gpuNum
:
0
gpuNum
:
0
...
...
test/scripts/nas.sh
View file @
063d6b74
...
@@ -7,7 +7,7 @@ echo "===========================Testing: NAS==========================="
...
@@ -7,7 +7,7 @@ echo "===========================Testing: NAS==========================="
EXAMPLE_DIR
=
${
CWD
}
/../examples/nas
EXAMPLE_DIR
=
${
CWD
}
/../examples/nas
echo
"testing nnictl ss_gen (classic nas)..."
echo
"testing nnictl ss_gen (classic nas)..."
cd
$EXAMPLE_DIR
/classic_nas
cd
$EXAMPLE_DIR
/
legacy/
classic_nas
SEARCH_SPACE_JSON
=
nni_auto_gen_search_space.json
SEARCH_SPACE_JSON
=
nni_auto_gen_search_space.json
if
[
-f
$SEARCH_SPACE_JSON
]
;
then
if
[
-f
$SEARCH_SPACE_JSON
]
;
then
rm
$SEARCH_SPACE_JSON
rm
$SEARCH_SPACE_JSON
...
@@ -19,12 +19,12 @@ if [ ! -f $SEARCH_SPACE_JSON ]; then
...
@@ -19,12 +19,12 @@ if [ ! -f $SEARCH_SPACE_JSON ]; then
fi
fi
echo
"testing darts..."
echo
"testing darts..."
cd
$EXAMPLE_DIR
/darts
cd
$EXAMPLE_DIR
/
oneshot/
darts
python3 search.py
--epochs
1
--channels
2
--layers
4
python3 search.py
--epochs
1
--channels
2
--layers
4
python3 retrain.py
--arc-checkpoint
./checkpoint.json
--layers
4
--epochs
1
python3 retrain.py
--arc-checkpoint
./checkpoint.json
--layers
4
--epochs
1
echo
"testing enas..."
echo
"testing enas..."
cd
$EXAMPLE_DIR
/enas
cd
$EXAMPLE_DIR
/
oneshot/
enas
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
...
@@ -34,5 +34,5 @@ python3 search.py --search-for micro --epochs 1
...
@@ -34,5 +34,5 @@ python3 search.py --search-for micro --epochs 1
#python3 train.py
#python3 train.py
echo
"testing pdarts..."
echo
"testing pdarts..."
cd
$EXAMPLE_DIR
/pdarts
cd
$EXAMPLE_DIR
/
legacy/
pdarts
python3 search.py
--epochs
1
--channels
4
--nodes
2
--log-frequency
10
--add_layers
0
--add_layers
1
--dropped_ops
3
--dropped_ops
3
python3 search.py
--epochs
1
--channels
4
--nodes
2
--log-frequency
10
--add_layers
0
--add_layers
1
--dropped_ops
3
--dropped_ops
3
test/ut/retiarii/test_strategy.py
View file @
063d6b74
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
return
self
.
_resource_left
return
self
.
_resource_left
def
budget_exhausted
(
self
)
->
bool
:
pass
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
pass
pass
...
...
ts/nni_manager/common/experimentConfig.ts
View file @
063d6b74
...
@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
...
@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
workspaceName
:
string
;
workspaceName
:
string
;
computeTarget
:
string
;
computeTarget
:
string
;
dockerImage
:
string
;
dockerImage
:
string
;
maxTrialNumberPerGpu
:
number
;
}
}
/* Kubeflow */
/* Kubeflow */
...
...
ts/nni_manager/common/manager.ts
View file @
063d6b74
...
@@ -8,7 +8,7 @@ import { TrialJobStatus, LogType } from './trainingService';
...
@@ -8,7 +8,7 @@ import { TrialJobStatus, LogType } from './trainingService';
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
type
ExperimentStatus
=
'
INITIALIZED
'
|
'
RUNNING
'
|
'
ERROR
'
|
'
STOPPING
'
|
'
STOPPED
'
|
'
DONE
'
|
'
NO_MORE_TRIAL
'
|
'
TUNER_NO_MORE_TRIAL
'
;
type
ExperimentStatus
=
'
INITIALIZED
'
|
'
RUNNING
'
|
'
ERROR
'
|
'
STOPPING
'
|
'
STOPPED
'
|
'
DONE
'
|
'
NO_MORE_TRIAL
'
|
'
TUNER_NO_MORE_TRIAL
'
|
'
VIEWED
'
;
namespace
ExperimentStartUpMode
{
namespace
ExperimentStartUpMode
{
export
const
NEW
=
'
new
'
;
export
const
NEW
=
'
new
'
;
export
const
RESUME
=
'
resume
'
;
export
const
RESUME
=
'
resume
'
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment