Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a5d614de
Unverified
Commit
a5d614de
authored
Nov 22, 2018
by
chicm-ms
Committed by
GitHub
Nov 22, 2018
Browse files
Asynchronous dispatcher (#372)
* Asynchronous dispatcher * updates * updates * updates * updates
parent
8d63b108
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
126 additions
and
31 deletions
+126
-31
src/nni_manager/common/manager.ts
src/nni_manager/common/manager.ts
+1
-0
src/nni_manager/common/utils.ts
src/nni_manager/common/utils.ts
+5
-1
src/nni_manager/core/commands.ts
src/nni_manager/core/commands.ts
+3
-0
src/nni_manager/core/nnimanager.ts
src/nni_manager/core/nnimanager.ts
+39
-12
src/nni_manager/rest_server/restValidationSchemas.ts
src/nni_manager/rest_server/restValidationSchemas.ts
+1
-0
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+4
-0
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+9
-0
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+11
-2
src/sdk/pynni/nni/msg_dispatcher_base.py
src/sdk/pynni/nni/msg_dispatcher_base.py
+25
-10
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+8
-0
src/sdk/pynni/nni/protocol.py
src/sdk/pynni/nni/protocol.py
+17
-6
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+1
-0
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+2
-0
No files found.
src/nni_manager/common/manager.ts
View file @
a5d614de
...
@@ -34,6 +34,7 @@ interface ExperimentParams {
...
@@ -34,6 +34,7 @@ interface ExperimentParams {
searchSpace
:
string
;
searchSpace
:
string
;
trainingServicePlatform
:
string
;
trainingServicePlatform
:
string
;
multiPhase
?:
boolean
;
multiPhase
?:
boolean
;
multiThread
?:
boolean
;
tuner
:
{
tuner
:
{
className
:
string
;
className
:
string
;
builtinTunerName
?:
string
;
builtinTunerName
?:
string
;
...
...
src/nni_manager/common/utils.ts
View file @
a5d614de
...
@@ -158,12 +158,16 @@ function parseArg(names: string[]): string {
...
@@ -158,12 +158,16 @@ function parseArg(names: string[]): string {
* @param assessor: similiar as tuner
* @param assessor: similiar as tuner
*
*
*/
*/
function
getMsgDispatcherCommand
(
tuner
:
any
,
assessor
:
any
,
multiPhase
:
boolean
=
false
):
string
{
function
getMsgDispatcherCommand
(
tuner
:
any
,
assessor
:
any
,
multiPhase
:
boolean
=
false
,
multiThread
:
boolean
=
false
):
string
{
let
command
:
string
=
`python3 -m nni --tuner_class_name
${
tuner
.
className
}
`
;
let
command
:
string
=
`python3 -m nni --tuner_class_name
${
tuner
.
className
}
`
;
if
(
multiPhase
)
{
if
(
multiPhase
)
{
command
+=
'
--multi_phase
'
;
command
+=
'
--multi_phase
'
;
}
}
if
(
multiThread
)
{
command
+=
'
--multi_thread
'
;
}
if
(
tuner
.
classArgs
!==
undefined
)
{
if
(
tuner
.
classArgs
!==
undefined
)
{
command
+=
` --tuner_args
${
JSON
.
stringify
(
JSON
.
stringify
(
tuner
.
classArgs
))}
`
;
command
+=
` --tuner_args
${
JSON
.
stringify
(
JSON
.
stringify
(
tuner
.
classArgs
))}
`
;
}
}
...
...
src/nni_manager/core/commands.ts
View file @
a5d614de
...
@@ -26,6 +26,7 @@ const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
...
@@ -26,6 +26,7 @@ const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
const
TRIAL_END
=
'
EN
'
;
const
TRIAL_END
=
'
EN
'
;
const
TERMINATE
=
'
TE
'
;
const
TERMINATE
=
'
TE
'
;
const
INITIALIZED
=
'
ID
'
;
const
NEW_TRIAL_JOB
=
'
TR
'
;
const
NEW_TRIAL_JOB
=
'
TR
'
;
const
SEND_TRIAL_JOB_PARAMETER
=
'
SP
'
;
const
SEND_TRIAL_JOB_PARAMETER
=
'
SP
'
;
const
NO_MORE_TRIAL_JOBS
=
'
NO
'
;
const
NO_MORE_TRIAL_JOBS
=
'
NO
'
;
...
@@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
...
@@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
ADD_CUSTOMIZED_TRIAL_JOB
,
ADD_CUSTOMIZED_TRIAL_JOB
,
TERMINATE
,
TERMINATE
,
INITIALIZED
,
NEW_TRIAL_JOB
,
NEW_TRIAL_JOB
,
SEND_TRIAL_JOB_PARAMETER
,
SEND_TRIAL_JOB_PARAMETER
,
NO_MORE_TRIAL_JOBS
NO_MORE_TRIAL_JOBS
...
@@ -61,6 +63,7 @@ export {
...
@@ -61,6 +63,7 @@ export {
ADD_CUSTOMIZED_TRIAL_JOB
,
ADD_CUSTOMIZED_TRIAL_JOB
,
TRIAL_END
,
TRIAL_END
,
TERMINATE
,
TERMINATE
,
INITIALIZED
,
NEW_TRIAL_JOB
,
NEW_TRIAL_JOB
,
NO_MORE_TRIAL_JOBS
,
NO_MORE_TRIAL_JOBS
,
KILL_TRIAL_JOB
,
KILL_TRIAL_JOB
,
...
...
src/nni_manager/core/nnimanager.ts
View file @
a5d614de
...
@@ -37,8 +37,8 @@ import {
...
@@ -37,8 +37,8 @@ import {
}
from
'
../common/trainingService
'
;
}
from
'
../common/trainingService
'
;
import
{
delay
,
getLogDir
,
getMsgDispatcherCommand
}
from
'
../common/utils
'
;
import
{
delay
,
getLogDir
,
getMsgDispatcherCommand
}
from
'
../common/utils
'
;
import
{
import
{
ADD_CUSTOMIZED_TRIAL_JOB
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
NO_MORE_TRIAL_JOBS
,
REPORT_METRIC_DATA
,
ADD_CUSTOMIZED_TRIAL_JOB
,
INITIALIZE
,
INITIALIZED
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
NO_MORE_TRIAL_JOBS
,
REQUEST_TRIAL_JOBS
,
SEND_TRIAL_JOB_PARAMETER
,
TERMINATE
,
TRIAL_END
,
UPDATE_SEARCH_SPACE
REPORT_METRIC_DATA
,
REQUEST_TRIAL_JOBS
,
SEND_TRIAL_JOB_PARAMETER
,
TERMINATE
,
TRIAL_END
,
UPDATE_SEARCH_SPACE
}
from
'
./commands
'
;
}
from
'
./commands
'
;
import
{
createDispatcherInterface
,
IpcInterface
}
from
'
./ipcInterface
'
;
import
{
createDispatcherInterface
,
IpcInterface
}
from
'
./ipcInterface
'
;
...
@@ -127,7 +127,8 @@ class NNIManager implements Manager {
...
@@ -127,7 +127,8 @@ class NNIManager implements Manager {
this
.
trainingService
.
setClusterMetadata
(
'
multiPhase
'
,
expParams
.
multiPhase
.
toString
());
this
.
trainingService
.
setClusterMetadata
(
'
multiPhase
'
,
expParams
.
multiPhase
.
toString
());
}
}
const
dispatcherCommand
:
string
=
getMsgDispatcherCommand
(
expParams
.
tuner
,
expParams
.
assessor
,
expParams
.
multiPhase
);
const
dispatcherCommand
:
string
=
getMsgDispatcherCommand
(
expParams
.
tuner
,
expParams
.
assessor
,
expParams
.
multiPhase
,
expParams
.
multiThread
);
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
this
.
setupTuner
(
this
.
setupTuner
(
//expParams.tuner.tunerCommand,
//expParams.tuner.tunerCommand,
...
@@ -159,7 +160,8 @@ class NNIManager implements Manager {
...
@@ -159,7 +160,8 @@ class NNIManager implements Manager {
this
.
trainingService
.
setClusterMetadata
(
'
multiPhase
'
,
expParams
.
multiPhase
.
toString
());
this
.
trainingService
.
setClusterMetadata
(
'
multiPhase
'
,
expParams
.
multiPhase
.
toString
());
}
}
const
dispatcherCommand
:
string
=
getMsgDispatcherCommand
(
expParams
.
tuner
,
expParams
.
assessor
,
expParams
.
multiPhase
);
const
dispatcherCommand
:
string
=
getMsgDispatcherCommand
(
expParams
.
tuner
,
expParams
.
assessor
,
expParams
.
multiPhase
,
expParams
.
multiThread
);
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
this
.
setupTuner
(
this
.
setupTuner
(
dispatcherCommand
,
dispatcherCommand
,
...
@@ -419,16 +421,20 @@ class NNIManager implements Manager {
...
@@ -419,16 +421,20 @@ class NNIManager implements Manager {
}
else
{
}
else
{
this
.
trialConcurrencyChange
=
requestTrialNum
;
this
.
trialConcurrencyChange
=
requestTrialNum
;
}
}
for
(
let
i
:
number
=
0
;
i
<
requestTrialNum
;
i
++
)
{
const
requestCustomTrialNum
:
number
=
Math
.
min
(
requestTrialNum
,
this
.
customizedTrials
.
length
);
for
(
let
i
:
number
=
0
;
i
<
requestCustomTrialNum
;
i
++
)
{
// ask tuner for more trials
// ask tuner for more trials
if
(
this
.
customizedTrials
.
length
>
0
)
{
if
(
this
.
customizedTrials
.
length
>
0
)
{
const
hyperParams
:
string
|
undefined
=
this
.
customizedTrials
.
shift
();
const
hyperParams
:
string
|
undefined
=
this
.
customizedTrials
.
shift
();
this
.
dispatcher
.
sendCommand
(
ADD_CUSTOMIZED_TRIAL_JOB
,
hyperParams
);
this
.
dispatcher
.
sendCommand
(
ADD_CUSTOMIZED_TRIAL_JOB
,
hyperParams
);
}
else
{
this
.
dispatcher
.
sendCommand
(
REQUEST_TRIAL_JOBS
,
'
1
'
);
}
}
}
}
if
(
requestTrialNum
-
requestCustomTrialNum
>
0
)
{
this
.
requestTrialJobs
(
requestTrialNum
-
requestCustomTrialNum
);
}
// check maxtrialnum and maxduration here
// check maxtrialnum and maxduration here
if
(
this
.
experimentProfile
.
execDuration
>
this
.
experimentProfile
.
params
.
maxExecDuration
||
if
(
this
.
experimentProfile
.
execDuration
>
this
.
experimentProfile
.
params
.
maxExecDuration
||
this
.
currSubmittedTrialNum
>=
this
.
experimentProfile
.
params
.
maxTrialNum
)
{
this
.
currSubmittedTrialNum
>=
this
.
experimentProfile
.
params
.
maxTrialNum
)
{
...
@@ -526,11 +532,9 @@ class NNIManager implements Manager {
...
@@ -526,11 +532,9 @@ class NNIManager implements Manager {
if
(
this
.
dispatcher
===
undefined
)
{
if
(
this
.
dispatcher
===
undefined
)
{
throw
new
Error
(
'
Dispatcher error: tuner has not been setup
'
);
throw
new
Error
(
'
Dispatcher error: tuner has not been setup
'
);
}
}
// TO DO: we should send INITIALIZE command to tuner if user's tuner needs to run init method in tuner
this
.
log
.
debug
(
`Send tuner command: INITIALIZE:
${
this
.
experimentProfile
.
params
.
searchSpace
}
`
);
this
.
log
.
debug
(
`Send tuner command: update search space:
${
this
.
experimentProfile
.
params
.
searchSpace
}
`
);
// Tuner need to be initialized with search space before generating any hyper parameters
this
.
dispatcher
.
sendCommand
(
UPDATE_SEARCH_SPACE
,
this
.
experimentProfile
.
params
.
searchSpace
);
this
.
dispatcher
.
sendCommand
(
INITIALIZE
,
this
.
experimentProfile
.
params
.
searchSpace
);
this
.
log
.
debug
(
`Send tuner command:
${
this
.
experimentProfile
.
params
.
trialConcurrency
}
`
);
this
.
dispatcher
.
sendCommand
(
REQUEST_TRIAL_JOBS
,
String
(
this
.
experimentProfile
.
params
.
trialConcurrency
));
}
}
private
async
onTrialJobMetrics
(
metric
:
TrialJobMetric
):
Promise
<
void
>
{
private
async
onTrialJobMetrics
(
metric
:
TrialJobMetric
):
Promise
<
void
>
{
...
@@ -541,9 +545,32 @@ class NNIManager implements Manager {
...
@@ -541,9 +545,32 @@ class NNIManager implements Manager {
this
.
dispatcher
.
sendCommand
(
REPORT_METRIC_DATA
,
metric
.
data
);
this
.
dispatcher
.
sendCommand
(
REPORT_METRIC_DATA
,
metric
.
data
);
}
}
private
requestTrialJobs
(
jobNum
:
number
):
void
{
if
(
jobNum
<
1
)
{
return
;
}
if
(
this
.
dispatcher
===
undefined
)
{
throw
new
Error
(
'
Dispatcher error: tuner has not been setup
'
);
}
if
(
this
.
experimentProfile
.
params
.
multiThread
)
{
// Send multiple requests to ensure multiple hyper parameters are generated in non-blocking way.
// For a single REQUEST_TRIAL_JOBS request, hyper parameters are generated one by one
// sequentially.
for
(
let
i
:
number
=
0
;
i
<
jobNum
;
i
++
)
{
this
.
dispatcher
.
sendCommand
(
REQUEST_TRIAL_JOBS
,
'
1
'
);
}
}
else
{
this
.
dispatcher
.
sendCommand
(
REQUEST_TRIAL_JOBS
,
String
(
jobNum
));
}
}
private
async
onTunerCommand
(
commandType
:
string
,
content
:
string
):
Promise
<
void
>
{
private
async
onTunerCommand
(
commandType
:
string
,
content
:
string
):
Promise
<
void
>
{
this
.
log
.
info
(
`Command from tuner:
${
commandType
}
,
${
content
}
`
);
this
.
log
.
info
(
`Command from tuner:
${
commandType
}
,
${
content
}
`
);
switch
(
commandType
)
{
switch
(
commandType
)
{
case
INITIALIZED
:
// Tuner is intialized, search space is set, request tuner to generate hyper parameters
this
.
requestTrialJobs
(
this
.
experimentProfile
.
params
.
trialConcurrency
);
break
;
case
NEW_TRIAL_JOB
:
case
NEW_TRIAL_JOB
:
this
.
waitingTrials
.
push
(
content
);
this
.
waitingTrials
.
push
(
content
);
break
;
break
;
...
...
src/nni_manager/rest_server/restValidationSchemas.ts
View file @
a5d614de
...
@@ -68,6 +68,7 @@ export namespace ValidationSchemas {
...
@@ -68,6 +68,7 @@ export namespace ValidationSchemas {
searchSpace
:
joi
.
string
().
required
(),
searchSpace
:
joi
.
string
().
required
(),
maxExecDuration
:
joi
.
number
().
min
(
0
).
required
(),
maxExecDuration
:
joi
.
number
().
min
(
0
).
required
(),
multiPhase
:
joi
.
boolean
(),
multiPhase
:
joi
.
boolean
(),
multiThread
:
joi
.
boolean
(),
tuner
:
joi
.
object
({
tuner
:
joi
.
object
({
builtinTunerName
:
joi
.
string
().
valid
(
'
TPE
'
,
'
Random
'
,
'
Anneal
'
,
'
Evolution
'
,
'
SMAC
'
,
'
BatchTuner
'
,
'
GridSearch
'
),
builtinTunerName
:
joi
.
string
().
valid
(
'
TPE
'
,
'
Random
'
,
'
Anneal
'
,
'
Evolution
'
,
'
SMAC
'
,
'
BatchTuner
'
,
'
GridSearch
'
),
codeDir
:
joi
.
string
(),
codeDir
:
joi
.
string
(),
...
...
src/sdk/pynni/nni/__main__.py
View file @
a5d614de
...
@@ -28,6 +28,7 @@ import json
...
@@ -28,6 +28,7 @@ import json
import
importlib
import
importlib
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
from
nni.common
import
enable_multi_thread
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
=
logging
.
getLogger
(
'nni.main'
)
...
@@ -91,6 +92,7 @@ def parse_args():
...
@@ -91,6 +92,7 @@ def parse_args():
parser
.
add_argument
(
'--assessor_class_filename'
,
type
=
str
,
required
=
False
,
parser
.
add_argument
(
'--assessor_class_filename'
,
type
=
str
,
required
=
False
,
help
=
'Assessor class file path'
)
help
=
'Assessor class file path'
)
parser
.
add_argument
(
'--multi_phase'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--multi_phase'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--multi_thread'
,
action
=
'store_true'
)
flags
,
_
=
parser
.
parse_known_args
()
flags
,
_
=
parser
.
parse_known_args
()
return
flags
return
flags
...
@@ -101,6 +103,8 @@ def main():
...
@@ -101,6 +103,8 @@ def main():
'''
'''
args
=
parse_args
()
args
=
parse_args
()
if
args
.
multi_thread
:
enable_multi_thread
()
tuner
=
None
tuner
=
None
assessor
=
None
assessor
=
None
...
...
src/sdk/pynni/nni/common.py
View file @
a5d614de
...
@@ -78,3 +78,12 @@ def init_logger(logger_file_path):
...
@@ -78,3 +78,12 @@ def init_logger(logger_file_path):
logging
.
getLogger
(
'matplotlib'
).
setLevel
(
logging
.
INFO
)
logging
.
getLogger
(
'matplotlib'
).
setLevel
(
logging
.
INFO
)
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
_multi_thread
=
False
def
enable_multi_thread
():
global
_multi_thread
_multi_thread
=
True
def
multi_thread_enabled
():
return
_multi_thread
src/sdk/pynni/nni/msg_dispatcher.py
View file @
a5d614de
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
import
logging
import
logging
from
collections
import
defaultdict
from
collections
import
defaultdict
import
json_tricks
import
json_tricks
import
threading
from
.protocol
import
CommandType
,
send
from
.protocol
import
CommandType
,
send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.msg_dispatcher_base
import
MsgDispatcherBase
...
@@ -69,7 +70,7 @@ def _pack_parameter(parameter_id, params, customized=False):
...
@@ -69,7 +70,7 @@ def _pack_parameter(parameter_id, params, customized=False):
class
MsgDispatcher
(
MsgDispatcherBase
):
class
MsgDispatcher
(
MsgDispatcherBase
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
super
()
super
()
.
__init__
()
self
.
tuner
=
tuner
self
.
tuner
=
tuner
self
.
assessor
=
assessor
self
.
assessor
=
assessor
if
assessor
is
None
:
if
assessor
is
None
:
...
@@ -85,6 +86,14 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -85,6 +86,14 @@ class MsgDispatcher(MsgDispatcherBase):
if
self
.
assessor
is
not
None
:
if
self
.
assessor
is
not
None
:
self
.
assessor
.
save_checkpoint
()
self
.
assessor
.
save_checkpoint
()
def
handle_initialize
(
self
,
data
):
'''
data is search space
'''
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
# data: number or trial jobs
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
...
...
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @
a5d614de
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
import
os
import
os
import
logging
import
logging
import
json_tricks
import
json_tricks
from
multiprocessing.dummy
import
Pool
as
ThreadPool
from
.common
import
init_logger
from
.common
import
init_logger
,
multi_thread_enabled
from
.recoverable
import
Recoverable
from
.recoverable
import
Recoverable
from
.protocol
import
CommandType
,
receive
from
.protocol
import
CommandType
,
receive
...
@@ -31,6 +31,10 @@ init_logger('dispatcher.log')
...
@@ -31,6 +31,10 @@ init_logger('dispatcher.log')
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
class
MsgDispatcherBase
(
Recoverable
):
class
MsgDispatcherBase
(
Recoverable
):
def
__init__
(
self
):
if
multi_thread_enabled
():
self
.
pool
=
ThreadPool
()
def
run
(
self
):
def
run
(
self
):
"""Run the tuner.
"""Run the tuner.
This function will never return unless raise.
This function will never return unless raise.
...
@@ -39,17 +43,24 @@ class MsgDispatcherBase(Recoverable):
...
@@ -39,17 +43,24 @@ class MsgDispatcherBase(Recoverable):
if
mode
==
'resume'
:
if
mode
==
'resume'
:
self
.
load_checkpoint
()
self
.
load_checkpoint
()
while
self
.
handle_request
():
while
True
:
pass
_logger
.
info
(
'Terminated by NNI manager'
)
def
handle_request
(
self
):
_logger
.
debug
(
'waiting receive_message'
)
_logger
.
debug
(
'waiting receive_message'
)
command
,
data
=
receive
()
command
,
data
=
receive
()
if
command
is
None
:
if
command
is
None
:
return
False
break
if
multi_thread_enabled
():
self
.
pool
.
map_async
(
self
.
handle_request
,
[(
command
,
data
)])
else
:
self
.
handle_request
((
command
,
data
))
if
multi_thread_enabled
():
self
.
pool
.
close
()
self
.
pool
.
join
()
_logger
.
info
(
'Terminated by NNI manager'
)
def
handle_request
(
self
,
request
):
command
,
data
=
request
_logger
.
debug
(
'handle request: command: [{}], data: [{}]'
.
format
(
command
,
data
))
_logger
.
debug
(
'handle request: command: [{}], data: [{}]'
.
format
(
command
,
data
))
...
@@ -60,6 +71,7 @@ class MsgDispatcherBase(Recoverable):
...
@@ -60,6 +71,7 @@ class MsgDispatcherBase(Recoverable):
command_handlers
=
{
command_handlers
=
{
# Tunner commands:
# Tunner commands:
CommandType
.
Initialize
:
self
.
handle_initialize
,
CommandType
.
RequestTrialJobs
:
self
.
handle_request_trial_jobs
,
CommandType
.
RequestTrialJobs
:
self
.
handle_request_trial_jobs
,
CommandType
.
UpdateSearchSpace
:
self
.
handle_update_search_space
,
CommandType
.
UpdateSearchSpace
:
self
.
handle_update_search_space
,
CommandType
.
AddCustomizedTrialJob
:
self
.
handle_add_customized_trial
,
CommandType
.
AddCustomizedTrialJob
:
self
.
handle_add_customized_trial
,
...
@@ -74,6 +86,9 @@ class MsgDispatcherBase(Recoverable):
...
@@ -74,6 +86,9 @@ class MsgDispatcherBase(Recoverable):
return
command_handlers
[
command
](
data
)
return
command_handlers
[
command
](
data
)
def
handle_initialize
(
self
,
data
):
raise
NotImplementedError
(
'handle_initialize not implemented'
)
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
raise
NotImplementedError
(
'handle_request_trial_jobs not implemented'
)
raise
NotImplementedError
(
'handle_request_trial_jobs not implemented'
)
...
...
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
View file @
a5d614de
...
@@ -91,6 +91,14 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
...
@@ -91,6 +91,14 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
if
self
.
assessor
is
not
None
:
if
self
.
assessor
is
not
None
:
self
.
assessor
.
save_checkpoint
()
self
.
assessor
.
save_checkpoint
()
def
handle_initialize
(
self
,
data
):
'''
data is search space
'''
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
# data: number or trial jobs
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
...
...
src/sdk/pynni/nni/protocol.py
View file @
a5d614de
...
@@ -19,7 +19,9 @@
...
@@ -19,7 +19,9 @@
# ==================================================================================================
# ==================================================================================================
import
logging
import
logging
import
threading
from
enum
import
Enum
from
enum
import
Enum
from
.common
import
multi_thread_enabled
class
CommandType
(
Enum
):
class
CommandType
(
Enum
):
...
@@ -33,6 +35,7 @@ class CommandType(Enum):
...
@@ -33,6 +35,7 @@ class CommandType(Enum):
Terminate
=
b
'TE'
Terminate
=
b
'TE'
# out
# out
Initialized
=
b
'ID'
NewTrialJob
=
b
'TR'
NewTrialJob
=
b
'TR'
SendTrialJobParameter
=
b
'SP'
SendTrialJobParameter
=
b
'SP'
NoMoreTrialJobs
=
b
'NO'
NoMoreTrialJobs
=
b
'NO'
...
@@ -42,6 +45,7 @@ class CommandType(Enum):
...
@@ -42,6 +45,7 @@ class CommandType(Enum):
try
:
try
:
_in_file
=
open
(
3
,
'rb'
)
_in_file
=
open
(
3
,
'rb'
)
_out_file
=
open
(
4
,
'wb'
)
_out_file
=
open
(
4
,
'wb'
)
_lock
=
threading
.
Lock
()
except
OSError
:
except
OSError
:
_msg
=
'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
_msg
=
'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
import
logging
import
logging
...
@@ -53,12 +57,19 @@ def send(command, data):
...
@@ -53,12 +57,19 @@ def send(command, data):
command: CommandType object.
command: CommandType object.
data: string payload.
data: string payload.
"""
"""
global
_lock
try
:
if
multi_thread_enabled
():
_lock
.
acquire
()
data
=
data
.
encode
(
'utf8'
)
data
=
data
.
encode
(
'utf8'
)
assert
len
(
data
)
<
1000000
,
'Command too long'
assert
len
(
data
)
<
1000000
,
'Command too long'
msg
=
b
'%b%06d%b'
%
(
command
.
value
,
len
(
data
),
data
)
msg
=
b
'%b%06d%b'
%
(
command
.
value
,
len
(
data
),
data
)
logging
.
getLogger
(
__name__
).
debug
(
'Sending command, data: [%s]'
%
msg
)
logging
.
getLogger
(
__name__
).
debug
(
'Sending command, data: [%s]'
%
msg
)
_out_file
.
write
(
msg
)
_out_file
.
write
(
msg
)
_out_file
.
flush
()
_out_file
.
flush
()
finally
:
if
multi_thread_enabled
():
_lock
.
release
()
def
receive
():
def
receive
():
...
...
tools/nni_cmd/config_schema.py
View file @
a5d614de
...
@@ -31,6 +31,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
...
@@ -31,6 +31,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
'trainingServicePlatform'
:
And
(
str
,
lambda
x
:
x
in
[
'remote'
,
'local'
,
'pai'
,
'kubeflow'
]),
'trainingServicePlatform'
:
And
(
str
,
lambda
x
:
x
in
[
'remote'
,
'local'
,
'pai'
,
'kubeflow'
]),
Optional
(
'searchSpacePath'
):
os
.
path
.
exists
,
Optional
(
'searchSpacePath'
):
os
.
path
.
exists
,
Optional
(
'multiPhase'
):
bool
,
Optional
(
'multiPhase'
):
bool
,
Optional
(
'multiThread'
):
bool
,
'useAnnotation'
:
bool
,
'useAnnotation'
:
bool
,
'tuner'
:
Or
({
'tuner'
:
Or
({
'builtinTunerName'
:
Or
(
'TPE'
,
'Random'
,
'Anneal'
,
'Evolution'
,
'SMAC'
,
'BatchTuner'
,
'GridSearch'
),
'builtinTunerName'
:
Or
(
'TPE'
,
'Random'
,
'Anneal'
,
'Evolution'
,
'SMAC'
,
'BatchTuner'
,
'GridSearch'
),
...
...
tools/nni_cmd/launcher.py
View file @
a5d614de
...
@@ -196,6 +196,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -196,6 +196,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data
[
'description'
]
=
experiment_config
[
'description'
]
request_data
[
'description'
]
=
experiment_config
[
'description'
]
if
experiment_config
.
get
(
'multiPhase'
):
if
experiment_config
.
get
(
'multiPhase'
):
request_data
[
'multiPhase'
]
=
experiment_config
.
get
(
'multiPhase'
)
request_data
[
'multiPhase'
]
=
experiment_config
.
get
(
'multiPhase'
)
if
experiment_config
.
get
(
'multiThread'
):
request_data
[
'multiThread'
]
=
experiment_config
.
get
(
'multiThread'
)
request_data
[
'tuner'
]
=
experiment_config
[
'tuner'
]
request_data
[
'tuner'
]
=
experiment_config
[
'tuner'
]
if
'assessor'
in
experiment_config
:
if
'assessor'
in
experiment_config
:
request_data
[
'assessor'
]
=
experiment_config
[
'assessor'
]
request_data
[
'assessor'
]
=
experiment_config
[
'assessor'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment