Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c8ef4141
Commit
c8ef4141
authored
Apr 19, 2019
by
Zejun Lin
Committed by
SparkSnail
Apr 19, 2019
Browse files
Implement API for user to import data and export data of type `json` or `csv` (#980)
parent
1d9b0a99
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
171 additions
and
32 deletions
+171
-32
src/nni_manager/common/datastore.ts
src/nni_manager/common/datastore.ts
+1
-1
src/nni_manager/common/manager.ts
src/nni_manager/common/manager.ts
+1
-0
src/nni_manager/core/commands.ts
src/nni_manager/core/commands.ts
+3
-0
src/nni_manager/core/nnimanager.ts
src/nni_manager/core/nnimanager.ts
+12
-1
src/nni_manager/rest_server/restHandler.ts
src/nni_manager/rest_server/restHandler.ts
+11
-0
src/nni_manager/rest_server/test/mockedNNIManager.ts
src/nni_manager/rest_server/test/mockedNNIManager.ts
+3
-0
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+14
-8
src/sdk/pynni/nni/msg_dispatcher_base.py
src/sdk/pynni/nni/msg_dispatcher_base.py
+4
-0
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+7
-0
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
+6
-0
src/sdk/pynni/nni/protocol.py
src/sdk/pynni/nni/protocol.py
+1
-0
src/sdk/pynni/nni/tuner.py
src/sdk/pynni/nni/tuner.py
+6
-0
tools/nni_cmd/constants.py
tools/nni_cmd/constants.py
+14
-0
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+12
-5
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+13
-4
tools/nni_cmd/updater.py
tools/nni_cmd/updater.py
+56
-13
tools/nni_cmd/url_utils.py
tools/nni_cmd/url_utils.py
+7
-0
No files found.
src/nni_manager/common/datastore.ts
View file @
c8ef4141
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
import
{
ExperimentProfile
,
TrialJobStatistics
}
from
'
./manager
'
;
import
{
ExperimentProfile
,
TrialJobStatistics
}
from
'
./manager
'
;
import
{
TrialJobDetail
,
TrialJobStatus
}
from
'
./trainingService
'
;
import
{
TrialJobDetail
,
TrialJobStatus
}
from
'
./trainingService
'
;
type
TrialJobEvent
=
TrialJobStatus
|
'
USER_TO_CANCEL
'
|
'
ADD_CUSTOMIZED
'
|
'
ADD_HYPERPARAMETER
'
;
type
TrialJobEvent
=
TrialJobStatus
|
'
USER_TO_CANCEL
'
|
'
ADD_CUSTOMIZED
'
|
'
ADD_HYPERPARAMETER
'
|
'
IMPORT_DATA
'
;
type
MetricType
=
'
PERIODICAL
'
|
'
FINAL
'
|
'
CUSTOM
'
|
'
REQUEST_PARAMETER
'
;
type
MetricType
=
'
PERIODICAL
'
|
'
FINAL
'
|
'
CUSTOM
'
|
'
REQUEST_PARAMETER
'
;
interface
ExperimentProfileRecord
{
interface
ExperimentProfileRecord
{
...
...
src/nni_manager/common/manager.ts
View file @
c8ef4141
...
@@ -99,6 +99,7 @@ abstract class Manager {
...
@@ -99,6 +99,7 @@ abstract class Manager {
public
abstract
stopExperiment
():
Promise
<
void
>
;
public
abstract
stopExperiment
():
Promise
<
void
>
;
public
abstract
getExperimentProfile
():
Promise
<
ExperimentProfile
>
;
public
abstract
getExperimentProfile
():
Promise
<
ExperimentProfile
>
;
public
abstract
updateExperimentProfile
(
experimentProfile
:
ExperimentProfile
,
updateType
:
ProfileUpdateType
):
Promise
<
void
>
;
public
abstract
updateExperimentProfile
(
experimentProfile
:
ExperimentProfile
,
updateType
:
ProfileUpdateType
):
Promise
<
void
>
;
public
abstract
importData
(
data
:
string
):
Promise
<
void
>
;
public
abstract
addCustomizedTrialJob
(
hyperParams
:
string
):
Promise
<
void
>
;
public
abstract
addCustomizedTrialJob
(
hyperParams
:
string
):
Promise
<
void
>
;
public
abstract
cancelTrialJobByUser
(
trialJobId
:
string
):
Promise
<
void
>
;
public
abstract
cancelTrialJobByUser
(
trialJobId
:
string
):
Promise
<
void
>
;
...
...
src/nni_manager/core/commands.ts
View file @
c8ef4141
...
@@ -22,6 +22,7 @@ const INITIALIZE = 'IN';
...
@@ -22,6 +22,7 @@ const INITIALIZE = 'IN';
const
REQUEST_TRIAL_JOBS
=
'
GE
'
;
const
REQUEST_TRIAL_JOBS
=
'
GE
'
;
const
REPORT_METRIC_DATA
=
'
ME
'
;
const
REPORT_METRIC_DATA
=
'
ME
'
;
const
UPDATE_SEARCH_SPACE
=
'
SS
'
;
const
UPDATE_SEARCH_SPACE
=
'
SS
'
;
const
IMPORT_DATA
=
'
FD
'
const
ADD_CUSTOMIZED_TRIAL_JOB
=
'
AD
'
;
const
ADD_CUSTOMIZED_TRIAL_JOB
=
'
AD
'
;
const
TRIAL_END
=
'
EN
'
;
const
TRIAL_END
=
'
EN
'
;
const
TERMINATE
=
'
TE
'
;
const
TERMINATE
=
'
TE
'
;
...
@@ -38,6 +39,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
...
@@ -38,6 +39,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
REQUEST_TRIAL_JOBS
,
REQUEST_TRIAL_JOBS
,
REPORT_METRIC_DATA
,
REPORT_METRIC_DATA
,
UPDATE_SEARCH_SPACE
,
UPDATE_SEARCH_SPACE
,
IMPORT_DATA
,
ADD_CUSTOMIZED_TRIAL_JOB
,
ADD_CUSTOMIZED_TRIAL_JOB
,
TERMINATE
,
TERMINATE
,
PING
,
PING
,
...
@@ -62,6 +64,7 @@ export {
...
@@ -62,6 +64,7 @@ export {
REQUEST_TRIAL_JOBS
,
REQUEST_TRIAL_JOBS
,
REPORT_METRIC_DATA
,
REPORT_METRIC_DATA
,
UPDATE_SEARCH_SPACE
,
UPDATE_SEARCH_SPACE
,
IMPORT_DATA
,
ADD_CUSTOMIZED_TRIAL_JOB
,
ADD_CUSTOMIZED_TRIAL_JOB
,
TRIAL_END
,
TRIAL_END
,
TERMINATE
,
TERMINATE
,
...
...
src/nni_manager/core/nnimanager.ts
View file @
c8ef4141
...
@@ -38,7 +38,7 @@ import {
...
@@ -38,7 +38,7 @@ import {
import
{
delay
,
getCheckpointDir
,
getExperimentRootDir
,
getLogDir
,
getMsgDispatcherCommand
,
mkDirP
,
getLogLevel
}
from
'
../common/utils
'
;
import
{
delay
,
getCheckpointDir
,
getExperimentRootDir
,
getLogDir
,
getMsgDispatcherCommand
,
mkDirP
,
getLogLevel
}
from
'
../common/utils
'
;
import
{
import
{
ADD_CUSTOMIZED_TRIAL_JOB
,
INITIALIZE
,
INITIALIZED
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
NO_MORE_TRIAL_JOBS
,
PING
,
ADD_CUSTOMIZED_TRIAL_JOB
,
INITIALIZE
,
INITIALIZED
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
NO_MORE_TRIAL_JOBS
,
PING
,
REPORT_METRIC_DATA
,
REQUEST_TRIAL_JOBS
,
SEND_TRIAL_JOB_PARAMETER
,
TERMINATE
,
TRIAL_END
,
UPDATE_SEARCH_SPACE
REPORT_METRIC_DATA
,
REQUEST_TRIAL_JOBS
,
SEND_TRIAL_JOB_PARAMETER
,
TERMINATE
,
TRIAL_END
,
UPDATE_SEARCH_SPACE
,
IMPORT_DATA
}
from
'
./commands
'
;
}
from
'
./commands
'
;
import
{
createDispatcherInterface
,
IpcInterface
}
from
'
./ipcInterface
'
;
import
{
createDispatcherInterface
,
IpcInterface
}
from
'
./ipcInterface
'
;
...
@@ -99,6 +99,17 @@ class NNIManager implements Manager {
...
@@ -99,6 +99,17 @@ class NNIManager implements Manager {
return
this
.
storeExperimentProfile
();
return
this
.
storeExperimentProfile
();
}
}
public
importData
(
data
:
string
):
Promise
<
void
>
{
if
(
this
.
dispatcher
===
undefined
)
{
return
Promise
.
reject
(
new
Error
(
'
tuner has not been setup
'
)
);
}
this
.
dispatcher
.
sendCommand
(
IMPORT_DATA
,
data
);
return
this
.
dataStore
.
storeTrialJobEvent
(
'
IMPORT_DATA
'
,
''
,
data
);
}
public
addCustomizedTrialJob
(
hyperParams
:
string
):
Promise
<
void
>
{
public
addCustomizedTrialJob
(
hyperParams
:
string
):
Promise
<
void
>
{
if
(
this
.
currSubmittedTrialNum
>=
this
.
experimentProfile
.
params
.
maxTrialNum
)
{
if
(
this
.
currSubmittedTrialNum
>=
this
.
experimentProfile
.
params
.
maxTrialNum
)
{
return
Promise
.
reject
(
return
Promise
.
reject
(
...
...
src/nni_manager/rest_server/restHandler.ts
View file @
c8ef4141
...
@@ -63,6 +63,7 @@ class NNIRestHandler {
...
@@ -63,6 +63,7 @@ class NNIRestHandler {
this
.
checkStatus
(
router
);
this
.
checkStatus
(
router
);
this
.
getExperimentProfile
(
router
);
this
.
getExperimentProfile
(
router
);
this
.
updateExperimentProfile
(
router
);
this
.
updateExperimentProfile
(
router
);
this
.
importData
(
router
);
this
.
startExperiment
(
router
);
this
.
startExperiment
(
router
);
this
.
getTrialJobStatistics
(
router
);
this
.
getTrialJobStatistics
(
router
);
this
.
setClusterMetaData
(
router
);
this
.
setClusterMetaData
(
router
);
...
@@ -144,6 +145,16 @@ class NNIRestHandler {
...
@@ -144,6 +145,16 @@ class NNIRestHandler {
});
});
});
});
}
}
private
importData
(
router
:
Router
):
void
{
router
.
post
(
'
/experiment/import-data
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
this
.
nniManager
.
importData
(
JSON
.
stringify
(
req
.
body
)).
then
(()
=>
{
res
.
send
();
}).
catch
((
err
:
Error
)
=>
{
this
.
handle_error
(
err
,
res
);
});
});
}
private
startExperiment
(
router
:
Router
):
void
{
private
startExperiment
(
router
:
Router
):
void
{
router
.
post
(
'
/experiment
'
,
expressJoi
(
ValidationSchemas
.
STARTEXPERIMENT
),
(
req
:
Request
,
res
:
Response
)
=>
{
router
.
post
(
'
/experiment
'
,
expressJoi
(
ValidationSchemas
.
STARTEXPERIMENT
),
(
req
:
Request
,
res
:
Response
)
=>
{
...
...
src/nni_manager/rest_server/test/mockedNNIManager.ts
View file @
c8ef4141
...
@@ -46,6 +46,9 @@ export class MockedNNIManager extends Manager {
...
@@ -46,6 +46,9 @@ export class MockedNNIManager extends Manager {
public
updateExperimentProfile
(
experimentProfile
:
ExperimentProfile
,
updateType
:
ProfileUpdateType
):
Promise
<
void
>
{
public
updateExperimentProfile
(
experimentProfile
:
ExperimentProfile
,
updateType
:
ProfileUpdateType
):
Promise
<
void
>
{
return
Promise
.
resolve
();
return
Promise
.
resolve
();
}
}
public
importData
(
data
:
string
):
Promise
<
void
>
{
return
Promise
.
resolve
();
}
public
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
{
public
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
{
const
deferred
:
Deferred
<
TrialJobStatistics
[]
>
=
new
Deferred
<
TrialJobStatistics
[]
>
();
const
deferred
:
Deferred
<
TrialJobStatistics
[]
>
=
new
Deferred
<
TrialJobStatistics
[]
>
();
deferred
.
resolve
([{
deferred
.
resolve
([{
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
c8ef4141
...
@@ -109,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -109,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase):
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
def
handle_import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
# data: parameters
id_
=
_create_parameter_id
()
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""
"""
:param
data: a dict received from nni_manager, which contains:
data: a dict received from nni_manager, which contains:
- 'parameter_id': id of the trial
- 'parameter_id': id of the trial
- 'value': metric value reported by nni.report_final_result()
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
"""
if
data
[
'type'
]
==
'FINAL'
:
if
data
[
'type'
]
==
'FINAL'
:
self
.
_handle_final_metric_data
(
data
)
self
.
_handle_final_metric_data
(
data
)
...
@@ -135,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -135,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
"""
"""
data: it has three keys: trial_job_id, event, hyper_params
data: it has three keys: trial_job_id, event, hyper_params
trial_job_id: the id generated by training service
-
trial_job_id: the id generated by training service
event: the job's state
-
event: the job's state
hyper_params: the hyperparameters generated and returned by tuner
-
hyper_params: the hyperparameters generated and returned by tuner
"""
"""
trial_job_id
=
data
[
'trial_job_id'
]
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
_ended_trials
.
add
(
trial_job_id
)
...
...
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @
c8ef4141
...
@@ -144,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
...
@@ -144,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
CommandType
.
Initialize
:
self
.
handle_initialize
,
CommandType
.
Initialize
:
self
.
handle_initialize
,
CommandType
.
RequestTrialJobs
:
self
.
handle_request_trial_jobs
,
CommandType
.
RequestTrialJobs
:
self
.
handle_request_trial_jobs
,
CommandType
.
UpdateSearchSpace
:
self
.
handle_update_search_space
,
CommandType
.
UpdateSearchSpace
:
self
.
handle_update_search_space
,
CommandType
.
ImportData
:
self
.
handle_import_data
,
CommandType
.
AddCustomizedTrialJob
:
self
.
handle_add_customized_trial
,
CommandType
.
AddCustomizedTrialJob
:
self
.
handle_add_customized_trial
,
# Tunner/Assessor commands:
# Tunner/Assessor commands:
...
@@ -168,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
...
@@ -168,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
raise
NotImplementedError
(
'handle_update_search_space not implemented'
)
raise
NotImplementedError
(
'handle_update_search_space not implemented'
)
def
handle_import_data
(
self
,
data
):
raise
NotImplementedError
(
'handle_import_data not implemented'
)
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
raise
NotImplementedError
(
'handle_add_customized_trial not implemented'
)
raise
NotImplementedError
(
'handle_add_customized_trial not implemented'
)
...
...
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
View file @
c8ef4141
...
@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
...
@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
return
True
return
True
def
handle_import_data
(
self
,
data
):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
# data: parameters
id_
=
_create_parameter_id
()
id_
=
_create_parameter_id
()
...
...
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
View file @
c8ef4141
...
@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
...
@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
"""
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
load_checkpoint
(
self
):
def
load_checkpoint
(
self
):
"""Load the checkpoint of tuner.
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
path: checkpoint directory for tuner
...
...
src/sdk/pynni/nni/protocol.py
View file @
c8ef4141
...
@@ -30,6 +30,7 @@ class CommandType(Enum):
...
@@ -30,6 +30,7 @@ class CommandType(Enum):
RequestTrialJobs
=
b
'GE'
RequestTrialJobs
=
b
'GE'
ReportMetricData
=
b
'ME'
ReportMetricData
=
b
'ME'
UpdateSearchSpace
=
b
'SS'
UpdateSearchSpace
=
b
'SS'
ImportData
=
b
'FD'
AddCustomizedTrialJob
=
b
'AD'
AddCustomizedTrialJob
=
b
'AD'
TrialEnd
=
b
'EN'
TrialEnd
=
b
'EN'
Terminate
=
b
'TE'
Terminate
=
b
'TE'
...
...
src/sdk/pynni/nni/tuner.py
View file @
c8ef4141
...
@@ -98,6 +98,12 @@ class Tuner(Recoverable):
...
@@ -98,6 +98,12 @@ class Tuner(Recoverable):
checkpoin_path
=
self
.
get_checkpoint_path
()
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
_on_exit
(
self
):
def
_on_exit
(
self
):
pass
pass
...
...
tools/nni_cmd/constants.py
View file @
c8ef4141
...
@@ -80,6 +80,20 @@ PACKAGE_REQUIREMENTS = {
...
@@ -80,6 +80,20 @@ PACKAGE_REQUIREMENTS = {
'BOHB'
:
'bohb_advisor'
'BOHB'
:
'bohb_advisor'
}
}
TUNERS_SUPPORTING_IMPORT_DATA
=
{
'TPE'
,
'Anneal'
,
'GridSearch'
,
'MetisTuner'
,
'BOHB'
}
TUNERS_NO_NEED_TO_IMPORT_DATA
=
{
'Random'
,
'Batch_tuner'
,
'Hyperband'
}
COLOR_RED_FORMAT
=
'
\033
[1;31;31m%s
\033
[0m'
COLOR_RED_FORMAT
=
'
\033
[1;31;31m%s
\033
[0m'
COLOR_GREEN_FORMAT
=
'
\033
[1;32;32m%s
\033
[0m'
COLOR_GREEN_FORMAT
=
'
\033
[1;32;32m%s
\033
[0m'
...
...
tools/nni_cmd/nnictl.py
View file @
c8ef4141
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
import
argparse
import
argparse
import
pkg_resources
import
pkg_resources
from
.launcher
import
create_experiment
,
resume_experiment
from
.launcher
import
create_experiment
,
resume_experiment
from
.updater
import
update_searchspace
,
update_concurrency
,
update_duration
,
update_trialnum
from
.updater
import
update_searchspace
,
update_concurrency
,
update_duration
,
update_trialnum
,
import_data
from
.nnictl_utils
import
*
from
.nnictl_utils
import
*
from
.package_management
import
*
from
.package_management
import
*
from
.constants
import
*
from
.constants
import
*
...
@@ -101,10 +101,6 @@ def parse_args():
...
@@ -101,10 +101,6 @@ def parse_args():
parser_trial_kill
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_trial_kill
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_trial_kill
.
add_argument
(
'--trial_id'
,
'-T'
,
required
=
True
,
dest
=
'trial_id'
,
help
=
'the id of trial to be killed'
)
parser_trial_kill
.
add_argument
(
'--trial_id'
,
'-T'
,
required
=
True
,
dest
=
'trial_id'
,
help
=
'the id of trial to be killed'
)
parser_trial_kill
.
set_defaults
(
func
=
trial_kill
)
parser_trial_kill
.
set_defaults
(
func
=
trial_kill
)
parser_trial_export
=
parser_trial_subparsers
.
add_parser
(
'export'
,
help
=
'export trial job results to csv'
)
parser_trial_export
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_trial_export
.
add_argument
(
'--file'
,
'-f'
,
required
=
True
,
dest
=
'csv_path'
,
help
=
'target csv file path'
)
parser_trial_export
.
set_defaults
(
func
=
export_trials_data
)
#parse experiment command
#parse experiment command
parser_experiment
=
subparsers
.
add_parser
(
'experiment'
,
help
=
'get experiment information'
)
parser_experiment
=
subparsers
.
add_parser
(
'experiment'
,
help
=
'get experiment information'
)
...
@@ -119,6 +115,17 @@ def parse_args():
...
@@ -119,6 +115,17 @@ def parse_args():
parser_experiment_list
=
parser_experiment_subparsers
.
add_parser
(
'list'
,
help
=
'list all of running experiment ids'
)
parser_experiment_list
=
parser_experiment_subparsers
.
add_parser
(
'list'
,
help
=
'list all of running experiment ids'
)
parser_experiment_list
.
add_argument
(
'all'
,
nargs
=
'?'
,
help
=
'list all of experiments'
)
parser_experiment_list
.
add_argument
(
'all'
,
nargs
=
'?'
,
help
=
'list all of experiments'
)
parser_experiment_list
.
set_defaults
(
func
=
experiment_list
)
parser_experiment_list
.
set_defaults
(
func
=
experiment_list
)
#import tuning data
parser_import_data
=
parser_experiment_subparsers
.
add_parser
(
'import'
,
help
=
'import additional data'
)
parser_import_data
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_import_data
.
add_argument
(
'--filename'
,
'-f'
,
required
=
True
)
parser_import_data
.
set_defaults
(
func
=
import_data
)
#export trial data
parser_trial_export
=
parser_experiment_subparsers
.
add_parser
(
'export'
,
help
=
'export trial job results to csv or json'
)
parser_trial_export
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_trial_export
.
add_argument
(
'--type'
,
'-t'
,
choices
=
[
'json'
,
'csv'
],
required
=
True
,
dest
=
'type'
,
help
=
'target file type'
)
parser_trial_export
.
add_argument
(
'--filename'
,
'-f'
,
required
=
True
,
dest
=
'path'
,
help
=
'target file path'
)
parser_trial_export
.
set_defaults
(
func
=
export_trials_data
)
#TODO:finish webui function
#TODO:finish webui function
#parse board command
#parse board command
...
...
tools/nni_cmd/nnictl_utils.py
View file @
c8ef4141
...
@@ -505,10 +505,19 @@ def export_trials_data(args):
...
@@ -505,10 +505,19 @@ def export_trials_data(args):
# dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content])
# dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content])
# dframe.to_csv(args.csv_path, sep='\t')
# dframe.to_csv(args.csv_path, sep='\t')
records
=
parse_trial_data
(
content
)
records
=
parse_trial_data
(
content
)
with
open
(
args
.
csv_path
,
'w'
)
as
f_csv
:
if
args
.
type
==
'json'
:
writer
=
csv
.
DictWriter
(
f_csv
,
set
.
union
(
*
[
set
(
r
.
keys
())
for
r
in
records
]))
json_records
=
[]
writer
.
writeheader
()
for
trial
in
records
:
writer
.
writerows
(
records
)
value
=
trial
.
pop
(
'reward'
,
None
)
trial_id
=
trial
.
pop
(
'id'
,
None
)
json_records
.
append
({
'parameter'
:
trial
,
'value'
:
value
,
'id'
:
trial_id
})
with
open
(
args
.
path
,
'w'
)
as
file
:
if
args
.
type
==
'csv'
:
writer
=
csv
.
DictWriter
(
file
,
set
.
union
(
*
[
set
(
r
.
keys
())
for
r
in
records
]))
writer
.
writeheader
()
writer
.
writerows
(
records
)
else
:
json
.
dump
(
json_records
,
file
)
else
:
else
:
print_error
(
'Export failed...'
)
print_error
(
'Export failed...'
)
else
:
else
:
...
...
tools/nni_cmd/updater.py
View file @
c8ef4141
...
@@ -21,13 +21,13 @@
...
@@ -21,13 +21,13 @@
import
json
import
json
import
os
import
os
from
.rest_utils
import
rest_put
,
rest_get
,
check_rest_server_quick
,
check_response
from
.rest_utils
import
rest_put
,
rest_post
,
rest_get
,
check_rest_server_quick
,
check_response
from
.url_utils
import
experiment_url
from
.url_utils
import
experiment_url
,
import_data_url
from
.config_utils
import
Config
from
.config_utils
import
Config
from
.common_utils
import
get_json_content
from
.common_utils
import
get_json_content
,
print_normal
,
print_error
,
print_warning
from
.nnictl_utils
import
check_experiment_id
,
get_experiment_port
,
get_config_filename
from
.nnictl_utils
import
check_experiment_id
,
get_experiment_port
,
get_config_filename
from
.launcher_utils
import
parse_time
from
.launcher_utils
import
parse_time
from
.constants
import
REST_TIME_OUT
from
.constants
import
REST_TIME_OUT
,
TUNERS_SUPPORTING_IMPORT_DATA
,
TUNERS_NO_NEED_TO_IMPORT_DATA
def
validate_digit
(
value
,
start
,
end
):
def
validate_digit
(
value
,
start
,
end
):
'''validate if a digit is valid'''
'''validate if a digit is valid'''
...
@@ -39,6 +39,23 @@ def validate_file(path):
...
@@ -39,6 +39,23 @@ def validate_file(path):
if
not
os
.
path
.
exists
(
path
):
if
not
os
.
path
.
exists
(
path
):
raise
FileNotFoundError
(
'%s is not a valid file path'
%
path
)
raise
FileNotFoundError
(
'%s is not a valid file path'
%
path
)
def
validate_dispatcher
(
args
):
'''validate if the dispatcher of the experiment supports importing data'''
nni_config
=
Config
(
get_config_filename
(
args
)).
get_config
(
'experimentConfig'
)
if
nni_config
.
get
(
'tuner'
)
and
nni_config
[
'tuner'
].
get
(
'builtinTunerName'
):
dispatcher_name
=
nni_config
[
'tuner'
][
'builtinTunerName'
]
elif
nni_config
.
get
(
'advisor'
)
and
nni_config
[
'advisor'
].
get
(
'builtinAdvisorName'
):
dispatcher_name
=
nni_config
[
'advisor'
][
'builtinAdvisorName'
]
else
:
# otherwise it should be a customized one
return
if
dispatcher_name
not
in
TUNERS_SUPPORTING_IMPORT_DATA
:
if
dispatcher_name
in
TUNERS_NO_NEED_TO_IMPORT_DATA
:
print_warning
(
"There is no need to import data for %s"
%
dispatcher_name
)
exit
(
0
)
else
:
print_error
(
"%s does not support importing addtional data"
%
dispatcher_name
)
exit
(
1
)
def
load_search_space
(
path
):
def
load_search_space
(
path
):
'''load search space content'''
'''load search space content'''
content
=
json
.
dumps
(
get_json_content
(
path
))
content
=
json
.
dumps
(
get_json_content
(
path
))
...
@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value):
...
@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value):
if
response
and
check_response
(
response
):
if
response
and
check_response
(
response
):
return
response
return
response
else
:
else
:
print
(
'ERROR: r
estful server is not running...'
)
print
_error
(
'R
estful server is not running...'
)
return
None
return
None
def
update_searchspace
(
args
):
def
update_searchspace
(
args
):
...
@@ -80,18 +97,19 @@ def update_searchspace(args):
...
@@ -80,18 +97,19 @@ def update_searchspace(args):
args
.
port
=
get_experiment_port
(
args
)
args
.
port
=
get_experiment_port
(
args
)
if
args
.
port
is
not
None
:
if
args
.
port
is
not
None
:
if
update_experiment_profile
(
args
,
'searchSpace'
,
content
):
if
update_experiment_profile
(
args
,
'searchSpace'
,
content
):
print
(
'INFO: u
pdate %s success!'
%
'searchSpace'
)
print
_normal
(
'U
pdate %s success!'
%
'searchSpace'
)
else
:
else
:
print
(
'ERROR: update %s failed!'
%
'searchSpace'
)
print_error
(
'Update %s failed!'
%
'searchSpace'
)
def
update_concurrency
(
args
):
def
update_concurrency
(
args
):
validate_digit
(
args
.
value
,
1
,
1000
)
validate_digit
(
args
.
value
,
1
,
1000
)
args
.
port
=
get_experiment_port
(
args
)
args
.
port
=
get_experiment_port
(
args
)
if
args
.
port
is
not
None
:
if
args
.
port
is
not
None
:
if
update_experiment_profile
(
args
,
'trialConcurrency'
,
int
(
args
.
value
)):
if
update_experiment_profile
(
args
,
'trialConcurrency'
,
int
(
args
.
value
)):
print
(
'INFO: u
pdate %s success!'
%
'concurrency'
)
print
_normal
(
'U
pdate %s success!'
%
'concurrency'
)
else
:
else
:
print
(
'ERROR: u
pdate %s failed!'
%
'concurrency'
)
print
_error
(
'U
pdate %s failed!'
%
'concurrency'
)
def
update_duration
(
args
):
def
update_duration
(
args
):
#parse time, change time unit to seconds
#parse time, change time unit to seconds
...
@@ -99,13 +117,38 @@ def update_duration(args):
...
@@ -99,13 +117,38 @@ def update_duration(args):
args
.
port
=
get_experiment_port
(
args
)
args
.
port
=
get_experiment_port
(
args
)
if
args
.
port
is
not
None
:
if
args
.
port
is
not
None
:
if
update_experiment_profile
(
args
,
'maxExecDuration'
,
int
(
args
.
value
)):
if
update_experiment_profile
(
args
,
'maxExecDuration'
,
int
(
args
.
value
)):
print
(
'INFO: u
pdate %s success!'
%
'duration'
)
print
_normal
(
'U
pdate %s success!'
%
'duration'
)
else
:
else
:
print
(
'ERROR: u
pdate %s failed!'
%
'duration'
)
print
_error
(
'U
pdate %s failed!'
%
'duration'
)
def
update_trialnum
(
args
):
def
update_trialnum
(
args
):
validate_digit
(
args
.
value
,
1
,
999999999
)
validate_digit
(
args
.
value
,
1
,
999999999
)
if
update_experiment_profile
(
args
,
'maxTrialNum'
,
int
(
args
.
value
)):
if
update_experiment_profile
(
args
,
'maxTrialNum'
,
int
(
args
.
value
)):
print
(
'INFO: u
pdate %s success!'
%
'trialnum'
)
print
_normal
(
'U
pdate %s success!'
%
'trialnum'
)
else
:
else
:
print
(
'ERROR: update %s failed!'
%
'trialnum'
)
print_error
(
'Update %s failed!'
%
'trialnum'
)
\ No newline at end of file
def
import_data
(
args
):
'''import additional data to the experiment'''
validate_file
(
args
.
filename
)
validate_dispatcher
(
args
)
content
=
load_search_space
(
args
.
filename
)
args
.
port
=
get_experiment_port
(
args
)
if
args
.
port
is
not
None
:
if
import_data_to_restful_server
(
args
,
content
):
print_normal
(
'Import data success!'
)
else
:
print_error
(
'Import data failed!'
)
def
import_data_to_restful_server
(
args
,
content
):
'''call restful server to import data to the experiment'''
nni_config
=
Config
(
get_config_filename
(
args
))
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
running
,
_
=
check_rest_server_quick
(
rest_port
)
if
running
:
response
=
rest_post
(
import_data_url
(
rest_port
),
content
,
REST_TIME_OUT
)
if
response
and
check_response
(
response
):
return
response
else
:
print_error
(
'Restful server is not running...'
)
return
None
tools/nni_cmd/url_utils.py
View file @
c8ef4141
...
@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment'
...
@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment'
CLUSTER_METADATA_API
=
'/experiment/cluster-metadata'
CLUSTER_METADATA_API
=
'/experiment/cluster-metadata'
IMPORT_DATA_API
=
'/experiment/import-data'
CHECK_STATUS_API
=
'/check-status'
CHECK_STATUS_API
=
'/check-status'
TRIAL_JOBS_API
=
'/trial-jobs'
TRIAL_JOBS_API
=
'/trial-jobs'
...
@@ -46,6 +48,11 @@ def cluster_metadata_url(port):
...
@@ -46,6 +48,11 @@ def cluster_metadata_url(port):
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
,
port
,
API_ROOT_URL
,
CLUSTER_METADATA_API
)
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
,
port
,
API_ROOT_URL
,
CLUSTER_METADATA_API
)
def
import_data_url
(
port
):
'''get import_data_url'''
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
,
port
,
API_ROOT_URL
,
IMPORT_DATA_API
)
def
experiment_url
(
port
):
def
experiment_url
(
port
):
'''get experiment_url'''
'''get experiment_url'''
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
,
port
,
API_ROOT_URL
,
EXPERIMENT_API
)
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
,
port
,
API_ROOT_URL
,
EXPERIMENT_API
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment