Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
93f96d4f
Unverified
Commit
93f96d4f
authored
Jul 01, 2020
by
SparkSnail
Committed by
GitHub
Jul 01, 2020
Browse files
Support aml (#2615)
parent
f5caa193
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
131 additions
and
36 deletions
+131
-36
src/nni_manager/training_service/reusable/trialDispatcher.ts
src/nni_manager/training_service/reusable/trialDispatcher.ts
+40
-34
src/sdk/pynni/nni/platform/__init__.py
src/sdk/pynni/nni/platform/__init__.py
+1
-1
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+19
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+21
-0
tools/nni_trial_tool/aml_channel.py
tools/nni_trial_tool/aml_channel.py
+47
-0
tools/nni_trial_tool/trial_runner.py
tools/nni_trial_tool/trial_runner.py
+3
-0
No files found.
src/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
93f96d4f
...
...
@@ -9,7 +9,7 @@ import * as path from 'path';
import
{
Writable
}
from
'
stream
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
getExperimentId
,
getPlatform
,
getBasePort
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getBasePort
,
getExperimentId
,
getPlatform
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
import
{
delay
,
getExperimentRootDir
,
getLogLevel
,
getVersion
,
mkDirPSync
,
uniqueString
}
from
'
../../common/utils
'
;
...
...
@@ -19,9 +19,9 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import
{
TrialConfig
}
from
'
../common/trialConfig
'
;
import
{
TrialConfigMetadataKey
}
from
'
../common/trialConfigMetadataKey
'
;
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
WebCommandChannel
}
from
'
./channels/webCommandChannel
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInfomation
,
RunnerSettings
}
from
'
./environment
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
import
{
TrialDetail
}
from
'
./trial
'
;
...
...
@@ -40,6 +40,7 @@ class TrialDispatcher implements TrainingService {
private
readonly
metricsEmitter
:
EventEmitter
;
private
readonly
experimentId
:
string
;
private
readonly
experimentRootDir
:
string
;
private
enableVersionCheck
:
boolean
=
true
;
...
...
@@ -58,6 +59,8 @@ class TrialDispatcher implements TrainingService {
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
metricsEmitter
=
new
EventEmitter
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentRootDir
=
getExperimentRootDir
();
this
.
runnerSettings
=
new
RunnerSettings
();
this
.
runnerSettings
.
experimentId
=
this
.
experimentId
;
this
.
runnerSettings
.
platform
=
getPlatform
();
...
...
@@ -158,14 +161,14 @@ class TrialDispatcher implements TrainingService {
const
environmentService
=
component
.
get
<
EnvironmentService
>
(
EnvironmentService
);
this
.
commandEmitter
=
new
EventEmitter
();
this
.
commandChannel
=
new
Web
CommandChannel
(
this
.
commandEmitter
);
this
.
commandChannel
=
environmentService
.
get
CommandChannel
(
this
.
commandEmitter
);
// TODO it's a hard code of web channel, it needs to be improved.
this
.
runnerSettings
.
nniManagerPort
=
getBasePort
()
+
1
;
this
.
runnerSettings
.
commandChannel
=
this
.
commandChannel
.
channelName
;
// for AML channel, other channels can ignore this.
this
.
commandChannel
.
config
(
"
MetricEmitter
"
,
this
.
metricsEmitter
);
await
this
.
commandChannel
.
config
(
"
MetricEmitter
"
,
this
.
metricsEmitter
);
// start channel
this
.
commandEmitter
.
on
(
"
command
"
,
(
command
:
Command
):
void
=>
{
...
...
@@ -173,16 +176,25 @@ class TrialDispatcher implements TrainingService {
this
.
log
.
error
(
`TrialDispatcher: error on handle env
${
command
.
environment
.
id
}
command:
${
command
.
command
}
, data:
${
command
.
data
}
, error:
${
err
}
`
);
})
});
this
.
commandChannel
.
start
();
await
this
.
commandChannel
.
start
();
this
.
log
.
info
(
`TrialDispatcher: started channel:
${
this
.
commandChannel
.
constructor
.
name
}
`
);
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
`trial config shouldn't be undefined in run()`
);
}
if
(
environmentService
.
hasStorageService
)
{
this
.
log
.
info
(
`TrialDispatcher: copying code and settings.`
);
const
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
let
storageService
:
StorageService
;
if
(
environmentService
.
hasStorageService
)
{
this
.
log
.
debug
(
`TrialDispatcher: use existing storage service.`
);
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
}
else
{
this
.
log
.
debug
(
`TrialDispatcher: create temp storage service to temp folder.`
);
storageService
=
new
MountedStorageService
();
const
environmentLocalTempFolder
=
path
.
join
(
this
.
experimentRootDir
,
this
.
experimentId
,
"
environment-temp
"
);
storageService
.
initialize
(
this
.
trialConfig
.
codeDir
,
environmentLocalTempFolder
);
}
// Copy the compressed file to remoteDirectory and delete it
const
codeDir
=
path
.
resolve
(
this
.
trialConfig
.
codeDir
);
const
envDir
=
storageService
.
joinPath
(
"
envs
"
);
...
...
@@ -202,12 +214,12 @@ class TrialDispatcher implements TrainingService {
}
await
storageService
.
copyDirectory
(
trialToolsPath
,
envDir
,
true
);
}
}
this
.
log
.
info
(
`TrialDispatcher: run loop started.`
);
await
Promise
.
all
([
this
.
environmentMaintenanceLoop
(),
this
.
trialManagementLoop
(),
this
.
commandChannel
.
run
(),
]);
}
...
...
@@ -274,7 +286,7 @@ class TrialDispatcher implements TrainingService {
}
this
.
commandEmitter
.
off
(
"
command
"
,
this
.
handleCommand
);
this
.
commandChannel
.
stop
();
await
this
.
commandChannel
.
stop
();
}
private
async
environmentMaintenanceLoop
():
Promise
<
void
>
{
...
...
@@ -396,7 +408,6 @@ class TrialDispatcher implements TrainingService {
break
;
}
}
let
liveEnvironmentsCount
=
0
;
const
idleEnvironments
:
EnvironmentInformation
[]
=
[];
this
.
environments
.
forEach
((
environment
)
=>
{
...
...
@@ -407,7 +418,6 @@ class TrialDispatcher implements TrainingService {
}
}
});
while
(
idleEnvironments
.
length
>
0
&&
waitingTrials
.
length
>
0
)
{
const
trial
=
waitingTrials
.
shift
();
const
idleEnvironment
=
idleEnvironments
.
shift
();
...
...
@@ -442,14 +452,10 @@ class TrialDispatcher implements TrainingService {
environment
.
command
=
"
[ -d
\"
nni_trial_tool
\"
] && echo
\"
nni_trial_tool exists already
\"
|| (mkdir ./nni_trial_tool && tar -xof ../nni_trial_tool.tar.gz -C ./nni_trial_tool) && pip3 install websockets &&
"
+
environment
.
command
;
}
if
(
environmentService
.
hasStorageService
)
{
const
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
environment
.
workingFolder
=
storageService
.
joinPath
(
"
envs
"
,
envId
);
await
storageService
.
createDirectory
(
environment
.
workingFolder
);
}
environment
.
command
=
`mkdir -p envs/
${
envId
}
&& cd envs/
${
envId
}
&&
${
environment
.
command
}
`
;
this
.
environments
.
set
(
environment
.
id
,
environment
);
await
environmentService
.
startEnvironment
(
environment
);
this
.
environments
.
set
(
environment
.
id
,
environment
);
if
(
environment
.
status
===
"
FAILED
"
)
{
environment
.
isIdle
=
false
;
...
...
src/sdk/pynni/nni/platform/__init__.py
View file @
93f96d4f
...
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
):
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
):
from
.local
import
*
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
tools/nni_cmd/config_schema.py
View file @
93f96d4f
...
...
@@ -116,7 +116,7 @@ common_schema = {
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
),
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
...
@@ -234,6 +234,23 @@ dlts_config_schema = {
}
}
aml_trial_schema
=
{
'trial'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'command'
:
setType
(
'command'
,
str
),
'image'
:
setType
(
'image'
,
str
),
'computeTarget'
:
setType
(
'computeTarget'
,
str
)
}
}
aml_config_schema
=
{
'amlConfig'
:
{
'subscriptionId'
:
setType
(
'subscriptionId'
,
str
),
'resourceGroup'
:
setType
(
'resourceGroup'
,
str
),
'workspaceName'
:
setType
(
'workspaceName'
,
str
),
}
}
kubeflow_trial_schema
=
{
'trial'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
...
@@ -374,6 +391,7 @@ training_service_schema_dict = {
'paiYarn'
:
Schema
({
**
common_schema
,
**
pai_yarn_trial_schema
,
**
pai_yarn_config_schema
}),
'kubeflow'
:
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
}),
'frameworkcontroller'
:
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
}),
'aml'
:
Schema
({
**
common_schema
,
**
aml_trial_schema
,
**
aml_config_schema
}),
'dlts'
:
Schema
({
**
common_schema
,
**
dlts_trial_schema
,
**
dlts_config_schema
}),
}
...
...
tools/nni_cmd/launcher.py
View file @
93f96d4f
...
...
@@ -272,6 +272,25 @@ def set_dlts_config(experiment_config, port, config_file_name):
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_aml_config
(
experiment_config
,
port
,
config_file_name
):
'''set aml configuration'''
aml_config_data
=
dict
()
aml_config_data
[
'aml_config'
]
=
experiment_config
[
'amlConfig'
]
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
aml_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
return
result
,
message
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_experiment
(
experiment_config
,
mode
,
port
,
config_file_name
):
'''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data
=
dict
()
...
...
@@ -374,6 +393,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_frameworkcontroller_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'dlts'
:
config_result
,
err_msg
=
set_dlts_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'aml'
:
config_result
,
err_msg
=
set_aml_config
(
experiment_config
,
port
,
config_file_name
)
else
:
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
exit
(
1
)
...
...
tools/nni_trial_tool/aml_channel.py
0 → 100644
View file @
93f96d4f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
azureml.core.run
import
Run
# pylint: disable=import-error
from
.base_channel
import
BaseChannel
from
.log_utils
import
LogType
,
nni_log
class
AMLChannel
(
BaseChannel
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
run
=
Run
.
get_context
()
super
(
AMLChannel
,
self
).
__init__
(
args
)
self
.
current_message_index
=
-
1
def
_inner_open
(
self
):
pass
def
_inner_close
(
self
):
pass
def
_inner_send
(
self
,
message
):
try
:
self
.
run
.
log
(
'trial_runner'
,
message
.
decode
(
'utf8'
))
except
Exception
as
exception
:
nni_log
(
LogType
.
Error
,
'meet unhandled exception when send message: %s'
%
exception
)
def
_inner_receive
(
self
):
messages
=
[]
message_dict
=
self
.
run
.
get_metrics
()
if
'nni_manager'
not
in
message_dict
:
return
[]
message_list
=
message_dict
[
'nni_manager'
]
if
not
message_list
:
return
messages
if
type
(
message_list
)
is
list
:
if
self
.
current_message_index
<
len
(
message_list
)
-
1
:
messages
=
message_list
[
self
.
current_message_index
+
1
:
len
(
message_list
)]
self
.
current_message_index
=
len
(
message_list
)
-
1
elif
self
.
current_message_index
==
-
1
:
messages
=
[
message_list
]
self
.
current_message_index
+=
1
newMessage
=
[]
for
message
in
messages
:
# receive message is string, to get consistent result, encode it here.
newMessage
.
append
(
message
.
encode
(
'utf8'
))
return
newMessage
tools/nni_trial_tool/trial_runner.py
View file @
93f96d4f
...
...
@@ -210,6 +210,9 @@ if __name__ == '__main__':
command_channel
=
None
if
args
.
command_channel
==
"file"
:
command_channel
=
FileChannel
(
args
)
elif
args
.
command_channel
==
'aml'
:
from
.aml_channel
import
AMLChannel
command_channel
=
AMLChannel
(
args
)
else
:
command_channel
=
WebChannel
(
args
)
command_channel
.
open
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment