Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
93f96d4f
Unverified
Commit
93f96d4f
authored
Jul 01, 2020
by
SparkSnail
Committed by
GitHub
Jul 01, 2020
Browse files
Support aml (#2615)
parent
f5caa193
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
131 additions
and
36 deletions
+131
-36
src/nni_manager/training_service/reusable/trialDispatcher.ts
src/nni_manager/training_service/reusable/trialDispatcher.ts
+40
-34
src/sdk/pynni/nni/platform/__init__.py
src/sdk/pynni/nni/platform/__init__.py
+1
-1
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+19
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+21
-0
tools/nni_trial_tool/aml_channel.py
tools/nni_trial_tool/aml_channel.py
+47
-0
tools/nni_trial_tool/trial_runner.py
tools/nni_trial_tool/trial_runner.py
+3
-0
No files found.
src/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
93f96d4f
...
@@ -9,7 +9,7 @@ import * as path from 'path';
...
@@ -9,7 +9,7 @@ import * as path from 'path';
import
{
Writable
}
from
'
stream
'
;
import
{
Writable
}
from
'
stream
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
*
as
component
from
'
../../common/component
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
getExperimentId
,
getPlatform
,
getBasePort
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getBasePort
,
getExperimentId
,
getPlatform
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
import
{
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
import
{
delay
,
getExperimentRootDir
,
getLogLevel
,
getVersion
,
mkDirPSync
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
delay
,
getExperimentRootDir
,
getLogLevel
,
getVersion
,
mkDirPSync
,
uniqueString
}
from
'
../../common/utils
'
;
...
@@ -19,9 +19,9 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
...
@@ -19,9 +19,9 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import
{
TrialConfig
}
from
'
../common/trialConfig
'
;
import
{
TrialConfig
}
from
'
../common/trialConfig
'
;
import
{
TrialConfigMetadataKey
}
from
'
../common/trialConfigMetadataKey
'
;
import
{
TrialConfigMetadataKey
}
from
'
../common/trialConfigMetadataKey
'
;
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
WebCommandChannel
}
from
'
./channels/webCommandChannel
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInfomation
,
RunnerSettings
}
from
'
./environment
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInfomation
,
RunnerSettings
}
from
'
./environment
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
import
{
TrialDetail
}
from
'
./trial
'
;
import
{
TrialDetail
}
from
'
./trial
'
;
...
@@ -40,6 +40,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -40,6 +40,7 @@ class TrialDispatcher implements TrainingService {
private
readonly
metricsEmitter
:
EventEmitter
;
private
readonly
metricsEmitter
:
EventEmitter
;
private
readonly
experimentId
:
string
;
private
readonly
experimentId
:
string
;
private
readonly
experimentRootDir
:
string
;
private
enableVersionCheck
:
boolean
=
true
;
private
enableVersionCheck
:
boolean
=
true
;
...
@@ -58,6 +59,8 @@ class TrialDispatcher implements TrainingService {
...
@@ -58,6 +59,8 @@ class TrialDispatcher implements TrainingService {
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
metricsEmitter
=
new
EventEmitter
();
this
.
metricsEmitter
=
new
EventEmitter
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentRootDir
=
getExperimentRootDir
();
this
.
runnerSettings
=
new
RunnerSettings
();
this
.
runnerSettings
=
new
RunnerSettings
();
this
.
runnerSettings
.
experimentId
=
this
.
experimentId
;
this
.
runnerSettings
.
experimentId
=
this
.
experimentId
;
this
.
runnerSettings
.
platform
=
getPlatform
();
this
.
runnerSettings
.
platform
=
getPlatform
();
...
@@ -158,14 +161,14 @@ class TrialDispatcher implements TrainingService {
...
@@ -158,14 +161,14 @@ class TrialDispatcher implements TrainingService {
const
environmentService
=
component
.
get
<
EnvironmentService
>
(
EnvironmentService
);
const
environmentService
=
component
.
get
<
EnvironmentService
>
(
EnvironmentService
);
this
.
commandEmitter
=
new
EventEmitter
();
this
.
commandEmitter
=
new
EventEmitter
();
this
.
commandChannel
=
new
Web
CommandChannel
(
this
.
commandEmitter
);
this
.
commandChannel
=
environmentService
.
get
CommandChannel
(
this
.
commandEmitter
);
// TODO it's a hard code of web channel, it needs to be improved.
// TODO it's a hard code of web channel, it needs to be improved.
this
.
runnerSettings
.
nniManagerPort
=
getBasePort
()
+
1
;
this
.
runnerSettings
.
nniManagerPort
=
getBasePort
()
+
1
;
this
.
runnerSettings
.
commandChannel
=
this
.
commandChannel
.
channelName
;
this
.
runnerSettings
.
commandChannel
=
this
.
commandChannel
.
channelName
;
// for AML channel, other channels can ignore this.
// for AML channel, other channels can ignore this.
this
.
commandChannel
.
config
(
"
MetricEmitter
"
,
this
.
metricsEmitter
);
await
this
.
commandChannel
.
config
(
"
MetricEmitter
"
,
this
.
metricsEmitter
);
// start channel
// start channel
this
.
commandEmitter
.
on
(
"
command
"
,
(
command
:
Command
):
void
=>
{
this
.
commandEmitter
.
on
(
"
command
"
,
(
command
:
Command
):
void
=>
{
...
@@ -173,41 +176,50 @@ class TrialDispatcher implements TrainingService {
...
@@ -173,41 +176,50 @@ class TrialDispatcher implements TrainingService {
this
.
log
.
error
(
`TrialDispatcher: error on handle env
${
command
.
environment
.
id
}
command:
${
command
.
command
}
, data:
${
command
.
data
}
, error:
${
err
}
`
);
this
.
log
.
error
(
`TrialDispatcher: error on handle env
${
command
.
environment
.
id
}
command:
${
command
.
command
}
, data:
${
command
.
data
}
, error:
${
err
}
`
);
})
})
});
});
this
.
commandChannel
.
start
();
await
this
.
commandChannel
.
start
();
this
.
log
.
info
(
`TrialDispatcher: started channel:
${
this
.
commandChannel
.
constructor
.
name
}
`
);
this
.
log
.
info
(
`TrialDispatcher: started channel:
${
this
.
commandChannel
.
constructor
.
name
}
`
);
if
(
this
.
trialConfig
===
undefined
)
{
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
`trial config shouldn't be undefined in run()`
);
throw
new
Error
(
`trial config shouldn't be undefined in run()`
);
}
}
this
.
log
.
info
(
`TrialDispatcher: copying code and settings.`
);
let
storageService
:
StorageService
;
if
(
environmentService
.
hasStorageService
)
{
if
(
environmentService
.
hasStorageService
)
{
this
.
log
.
info
(
`TrialDispatcher: copying code and settings.`
);
this
.
log
.
debug
(
`TrialDispatcher: use existing storage service.`
);
const
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
// Copy the compressed file to remoteDirectory and delete it
}
else
{
const
codeDir
=
path
.
resolve
(
this
.
trialConfig
.
codeDir
);
this
.
log
.
debug
(
`TrialDispatcher: create temp storage service to temp folder.`
);
const
envDir
=
storageService
.
joinPath
(
"
envs
"
);
storageService
=
new
MountedStorageService
();
const
codeFileName
=
await
storageService
.
copyDirectory
(
codeDir
,
envDir
,
true
);
const
environmentLocalTempFolder
=
path
.
join
(
this
.
experimentRootDir
,
this
.
experimentId
,
"
environment-temp
"
);
storageService
.
rename
(
codeFileName
,
"
nni-code.tar.gz
"
);
storageService
.
initialize
(
this
.
trialConfig
.
codeDir
,
environmentLocalTempFolder
);
}
const
installFileName
=
storageService
.
joinPath
(
envDir
,
'
install_nni.sh
'
);
await
storageService
.
save
(
CONTAINER_INSTALL_NNI_SHELL_FORMAT
,
installFileName
);
// Copy the compressed file to remoteDirectory and delete it
const
codeDir
=
path
.
resolve
(
this
.
trialConfig
.
codeDir
);
const
runnerSettings
=
storageService
.
joinPath
(
envDir
,
"
settings.json
"
);
const
envDir
=
storageService
.
joinPath
(
"
envs
"
);
await
storageService
.
save
(
JSON
.
stringify
(
this
.
runnerSettings
),
runnerSettings
);
const
codeFileName
=
await
storageService
.
copyDirectory
(
codeDir
,
envDir
,
true
);
storageService
.
rename
(
codeFileName
,
"
nni-code.tar.gz
"
);
if
(
this
.
isDeveloping
)
{
let
trialToolsPath
=
path
.
join
(
__dirname
,
"
../../../../../tools/nni_trial_tool
"
);
const
installFileName
=
storageService
.
joinPath
(
envDir
,
'
install_nni.sh
'
);
if
(
false
===
fs
.
existsSync
(
trialToolsPath
))
{
await
storageService
.
save
(
CONTAINER_INSTALL_NNI_SHELL_FORMAT
,
installFileName
);
trialToolsPath
=
path
.
join
(
__dirname
,
"
..
\\
..
\\
..
\\
..
\\
..
\\
tools
\\
nni_trial_tool
"
);
}
const
runnerSettings
=
storageService
.
joinPath
(
envDir
,
"
settings.json
"
);
await
storageService
.
copyDirectory
(
trialToolsPath
,
envDir
,
true
);
await
storageService
.
save
(
JSON
.
stringify
(
this
.
runnerSettings
),
runnerSettings
);
if
(
this
.
isDeveloping
)
{
let
trialToolsPath
=
path
.
join
(
__dirname
,
"
../../../../../tools/nni_trial_tool
"
);
if
(
false
===
fs
.
existsSync
(
trialToolsPath
))
{
trialToolsPath
=
path
.
join
(
__dirname
,
"
..
\\
..
\\
..
\\
..
\\
..
\\
tools
\\
nni_trial_tool
"
);
}
}
await
storageService
.
copyDirectory
(
trialToolsPath
,
envDir
,
true
);
}
}
this
.
log
.
info
(
`TrialDispatcher: run loop started.`
);
this
.
log
.
info
(
`TrialDispatcher: run loop started.`
);
await
Promise
.
all
([
await
Promise
.
all
([
this
.
environmentMaintenanceLoop
(),
this
.
environmentMaintenanceLoop
(),
this
.
trialManagementLoop
(),
this
.
trialManagementLoop
(),
this
.
commandChannel
.
run
(),
]);
]);
}
}
...
@@ -274,7 +286,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -274,7 +286,7 @@ class TrialDispatcher implements TrainingService {
}
}
this
.
commandEmitter
.
off
(
"
command
"
,
this
.
handleCommand
);
this
.
commandEmitter
.
off
(
"
command
"
,
this
.
handleCommand
);
this
.
commandChannel
.
stop
();
await
this
.
commandChannel
.
stop
();
}
}
private
async
environmentMaintenanceLoop
():
Promise
<
void
>
{
private
async
environmentMaintenanceLoop
():
Promise
<
void
>
{
...
@@ -396,7 +408,6 @@ class TrialDispatcher implements TrainingService {
...
@@ -396,7 +408,6 @@ class TrialDispatcher implements TrainingService {
break
;
break
;
}
}
}
}
let
liveEnvironmentsCount
=
0
;
let
liveEnvironmentsCount
=
0
;
const
idleEnvironments
:
EnvironmentInformation
[]
=
[];
const
idleEnvironments
:
EnvironmentInformation
[]
=
[];
this
.
environments
.
forEach
((
environment
)
=>
{
this
.
environments
.
forEach
((
environment
)
=>
{
...
@@ -407,7 +418,6 @@ class TrialDispatcher implements TrainingService {
...
@@ -407,7 +418,6 @@ class TrialDispatcher implements TrainingService {
}
}
}
}
});
});
while
(
idleEnvironments
.
length
>
0
&&
waitingTrials
.
length
>
0
)
{
while
(
idleEnvironments
.
length
>
0
&&
waitingTrials
.
length
>
0
)
{
const
trial
=
waitingTrials
.
shift
();
const
trial
=
waitingTrials
.
shift
();
const
idleEnvironment
=
idleEnvironments
.
shift
();
const
idleEnvironment
=
idleEnvironments
.
shift
();
...
@@ -442,14 +452,10 @@ class TrialDispatcher implements TrainingService {
...
@@ -442,14 +452,10 @@ class TrialDispatcher implements TrainingService {
environment
.
command
=
"
[ -d
\"
nni_trial_tool
\"
] && echo
\"
nni_trial_tool exists already
\"
|| (mkdir ./nni_trial_tool && tar -xof ../nni_trial_tool.tar.gz -C ./nni_trial_tool) && pip3 install websockets &&
"
+
environment
.
command
;
environment
.
command
=
"
[ -d
\"
nni_trial_tool
\"
] && echo
\"
nni_trial_tool exists already
\"
|| (mkdir ./nni_trial_tool && tar -xof ../nni_trial_tool.tar.gz -C ./nni_trial_tool) && pip3 install websockets &&
"
+
environment
.
command
;
}
}
if
(
environmentService
.
hasStorageService
)
{
environment
.
command
=
`mkdir -p envs/
${
envId
}
&& cd envs/
${
envId
}
&&
${
environment
.
command
}
`
;
const
storageService
=
component
.
get
<
StorageService
>
(
StorageService
);
environment
.
workingFolder
=
storageService
.
joinPath
(
"
envs
"
,
envId
);
await
storageService
.
createDirectory
(
environment
.
workingFolder
);
}
this
.
environments
.
set
(
environment
.
id
,
environment
);
await
environmentService
.
startEnvironment
(
environment
);
await
environmentService
.
startEnvironment
(
environment
);
this
.
environments
.
set
(
environment
.
id
,
environment
);
if
(
environment
.
status
===
"
FAILED
"
)
{
if
(
environment
.
status
===
"
FAILED
"
)
{
environment
.
isIdle
=
false
;
environment
.
isIdle
=
false
;
...
...
src/sdk/pynni/nni/platform/__init__.py
View file @
93f96d4f
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
):
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
):
from
.local
import
*
from
.local
import
*
else
:
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
tools/nni_cmd/config_schema.py
View file @
93f96d4f
...
@@ -116,7 +116,7 @@ common_schema = {
...
@@ -116,7 +116,7 @@ common_schema = {
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
),
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
@@ -234,6 +234,23 @@ dlts_config_schema = {
...
@@ -234,6 +234,23 @@ dlts_config_schema = {
}
}
}
}
aml_trial_schema
=
{
'trial'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'command'
:
setType
(
'command'
,
str
),
'image'
:
setType
(
'image'
,
str
),
'computeTarget'
:
setType
(
'computeTarget'
,
str
)
}
}
aml_config_schema
=
{
'amlConfig'
:
{
'subscriptionId'
:
setType
(
'subscriptionId'
,
str
),
'resourceGroup'
:
setType
(
'resourceGroup'
,
str
),
'workspaceName'
:
setType
(
'workspaceName'
,
str
),
}
}
kubeflow_trial_schema
=
{
kubeflow_trial_schema
=
{
'trial'
:{
'trial'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
@@ -374,6 +391,7 @@ training_service_schema_dict = {
...
@@ -374,6 +391,7 @@ training_service_schema_dict = {
'paiYarn'
:
Schema
({
**
common_schema
,
**
pai_yarn_trial_schema
,
**
pai_yarn_config_schema
}),
'paiYarn'
:
Schema
({
**
common_schema
,
**
pai_yarn_trial_schema
,
**
pai_yarn_config_schema
}),
'kubeflow'
:
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
}),
'kubeflow'
:
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
}),
'frameworkcontroller'
:
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
}),
'frameworkcontroller'
:
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
}),
'aml'
:
Schema
({
**
common_schema
,
**
aml_trial_schema
,
**
aml_config_schema
}),
'dlts'
:
Schema
({
**
common_schema
,
**
dlts_trial_schema
,
**
dlts_config_schema
}),
'dlts'
:
Schema
({
**
common_schema
,
**
dlts_trial_schema
,
**
dlts_config_schema
}),
}
}
...
...
tools/nni_cmd/launcher.py
View file @
93f96d4f
...
@@ -272,6 +272,25 @@ def set_dlts_config(experiment_config, port, config_file_name):
...
@@ -272,6 +272,25 @@ def set_dlts_config(experiment_config, port, config_file_name):
#set trial_config
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_aml_config
(
experiment_config
,
port
,
config_file_name
):
'''set aml configuration'''
aml_config_data
=
dict
()
aml_config_data
[
'aml_config'
]
=
experiment_config
[
'amlConfig'
]
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
aml_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
return
result
,
message
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_experiment
(
experiment_config
,
mode
,
port
,
config_file_name
):
def
set_experiment
(
experiment_config
,
mode
,
port
,
config_file_name
):
'''Call startExperiment (rest POST /experiment) with yaml file content'''
'''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data
=
dict
()
request_data
=
dict
()
...
@@ -374,6 +393,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
...
@@ -374,6 +393,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_frameworkcontroller_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_frameworkcontroller_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'dlts'
:
elif
platform
==
'dlts'
:
config_result
,
err_msg
=
set_dlts_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_dlts_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'aml'
:
config_result
,
err_msg
=
set_aml_config
(
experiment_config
,
port
,
config_file_name
)
else
:
else
:
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
exit
(
1
)
exit
(
1
)
...
...
tools/nni_trial_tool/aml_channel.py
0 → 100644
View file @
93f96d4f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
azureml.core.run
import
Run
# pylint: disable=import-error
from
.base_channel
import
BaseChannel
from
.log_utils
import
LogType
,
nni_log
class
AMLChannel
(
BaseChannel
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
run
=
Run
.
get_context
()
super
(
AMLChannel
,
self
).
__init__
(
args
)
self
.
current_message_index
=
-
1
def
_inner_open
(
self
):
pass
def
_inner_close
(
self
):
pass
def
_inner_send
(
self
,
message
):
try
:
self
.
run
.
log
(
'trial_runner'
,
message
.
decode
(
'utf8'
))
except
Exception
as
exception
:
nni_log
(
LogType
.
Error
,
'meet unhandled exception when send message: %s'
%
exception
)
def
_inner_receive
(
self
):
messages
=
[]
message_dict
=
self
.
run
.
get_metrics
()
if
'nni_manager'
not
in
message_dict
:
return
[]
message_list
=
message_dict
[
'nni_manager'
]
if
not
message_list
:
return
messages
if
type
(
message_list
)
is
list
:
if
self
.
current_message_index
<
len
(
message_list
)
-
1
:
messages
=
message_list
[
self
.
current_message_index
+
1
:
len
(
message_list
)]
self
.
current_message_index
=
len
(
message_list
)
-
1
elif
self
.
current_message_index
==
-
1
:
messages
=
[
message_list
]
self
.
current_message_index
+=
1
newMessage
=
[]
for
message
in
messages
:
# receive message is string, to get consistent result, encode it here.
newMessage
.
append
(
message
.
encode
(
'utf8'
))
return
newMessage
tools/nni_trial_tool/trial_runner.py
View file @
93f96d4f
...
@@ -210,6 +210,9 @@ if __name__ == '__main__':
...
@@ -210,6 +210,9 @@ if __name__ == '__main__':
command_channel
=
None
command_channel
=
None
if
args
.
command_channel
==
"file"
:
if
args
.
command_channel
==
"file"
:
command_channel
=
FileChannel
(
args
)
command_channel
=
FileChannel
(
args
)
elif
args
.
command_channel
==
'aml'
:
from
.aml_channel
import
AMLChannel
command_channel
=
AMLChannel
(
args
)
else
:
else
:
command_channel
=
WebChannel
(
args
)
command_channel
=
WebChannel
(
args
)
command_channel
.
open
()
command_channel
.
open
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment