OpenDAS / nni
Commit 93f96d4f
Authored Jul 01, 2020 by SparkSnail; committed by GitHub on Jul 01, 2020
Support aml (#2615)
Parent: f5caa193
Showing 20 changed files with 660 additions and 9 deletions
docs/en_US/TrainingService/AMLMode.md (+66 -0)
docs/en_US/training_services.rst (+1 -0)
docs/img/aml_account.png (+0 -0)
examples/trials/mnist-pytorch/config_aml.yml (+25 -0)
examples/trials/mnist-tfv1/config_aml.yml (+25 -0)
src/nni_manager/config/aml/amlUtil.py (+56 -0)
src/nni_manager/main.ts (+6 -2)
src/nni_manager/package.json (+1 -0)
src/nni_manager/rest_server/restValidationSchemas.ts (+7 -0)
src/nni_manager/training_service/common/trialConfigMetadataKey.ts (+1 -0)
src/nni_manager/training_service/reusable/aml/amlClient.ts (+125 -0)
src/nni_manager/training_service/reusable/aml/amlConfig.ts (+39 -0)
src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts (+120 -0)
src/nni_manager/training_service/reusable/channels/fileCommandChannel.ts (+10 -4)
src/nni_manager/training_service/reusable/channels/webCommandChannel.ts (+4 -0)
src/nni_manager/training_service/reusable/commandChannel.ts (+3 -0)
src/nni_manager/training_service/reusable/environment.ts (+1 -1)
src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts (+147 -0)
src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts (+3 -2)
src/nni_manager/training_service/reusable/routerTrainingService.ts (+20 -0)
docs/en_US/TrainingService/AMLMode.md
0 → 100644
View file @
93f96d4f
**Run an Experiment on Azure Machine Learning**
===
NNI supports running an experiment on [AML](https://azure.microsoft.com/en-us/services/machine-learning/), called aml mode.

## Setup environment
Step 1. Install NNI by following the install guide [here](../Tutorial/QuickStart.md).

Step 2. Create an AML account by following the document [here](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli).

Step 3. Get your account information.

Step 4. Install the AML package environment.
```
python3 -m pip install azureml --user
python3 -m pip install azureml-sdk --user
```
## Run an experiment
Use `examples/trials/mnist-tfv1` as an example. The NNI config YAML file's content is like:

```yaml
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
```
Note: You should set `trainingServicePlatform: aml` in the NNI config YAML file if you want to start an experiment in aml mode.

Compared with [LocalMode](LocalMode.md), the trial configuration in aml mode has these additional keys:
* computeTarget
    * required key. The compute cluster name you want to use in your AML workspace.
* image
    * required key. The docker image name used in the job.

amlConfig:
* subscriptionId
    * the subscriptionId of your account
* resourceGroup
    * the resourceGroup of your account
* workspaceName
    * the workspaceName of your account
\ No newline at end of file
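The aml-mode keys described in AMLMode.md can be checked before an experiment is launched. The following is a minimal sketch, assuming the config YAML has already been loaded into a plain dict; `missing_aml_keys` and the two tuples are hypothetical names for illustration and are not part of NNI.

```python
# Keys that AMLMode.md names for aml mode (helper and constant names are hypothetical).
AML_TRIAL_KEYS = ("computeTarget", "image")
AML_CONFIG_KEYS = ("subscriptionId", "resourceGroup", "workspaceName")


def missing_aml_keys(config):
    """Return dotted names of aml-mode keys that are absent or empty."""
    missing = []
    trial = config.get("trial") or {}
    aml = config.get("amlConfig") or {}
    for key in AML_TRIAL_KEYS:
        if not trial.get(key):
            missing.append("trial." + key)
    for key in AML_CONFIG_KEYS:
        if not aml.get(key):
            missing.append("amlConfig." + key)
    return missing
```

An empty list means every key from the document is present; anything else names what still has to be filled in.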
docs/en_US/training_services.rst
View file @
93f96d4f
@@ -10,3 +10,4 @@ Introduction to NNI Training Services
     Kubeflow<./TrainingService/KubeflowMode>
     FrameworkController<./TrainingService/FrameworkControllerMode>
     DLTS<./TrainingService/DLTSMode>
+    AML<./TrainingService/AMLMode>
docs/img/aml_account.png
0 → 100644
View file @
93f96d4f
24.7 KB
examples/trials/mnist-pytorch/config_aml.yml
0 → 100644
View file @
93f96d4f
authorName: default
experimentName: example_mnist_pytorch
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
examples/trials/mnist-tfv1/config_aml.yml
0 → 100644
View file @
93f96d4f
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
src/nni_manager/config/aml/amlUtil.py
0 → 100644
View file @
93f96d4f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import time
import json
from argparse import ArgumentParser
from azureml.core import Experiment, RunConfiguration, ScriptRunConfig
from azureml.core.compute import ComputeTarget
from azureml.core.run import RUNNING_STATES, RunStatus, Run
from azureml.core import Workspace
from azureml.core.conda_dependencies import CondaDependencies

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--subscription_id', help='the subscription id of aml')
    parser.add_argument('--resource_group', help='the resource group of aml')
    parser.add_argument('--workspace_name', help='the workspace name of aml')
    parser.add_argument('--compute_target', help='the compute cluster name of aml')
    parser.add_argument('--docker_image', help='the docker image of job')
    parser.add_argument('--experiment_name', help='the experiment name')
    parser.add_argument('--script_dir', help='script directory')
    parser.add_argument('--script_name', help='script name')
    args = parser.parse_args()

    ws = Workspace(args.subscription_id, args.resource_group, args.workspace_name)
    compute_target = ComputeTarget(workspace=ws, name=args.compute_target)
    experiment = Experiment(ws, args.experiment_name)
    run_config = RunConfiguration()
    dependencies = CondaDependencies()
    dependencies.add_pip_package("azureml-sdk")
    dependencies.add_pip_package("azureml")
    run_config.environment.python.conda_dependencies = dependencies
    run_config.environment.docker.enabled = True
    run_config.environment.docker.base_image = args.docker_image
    run_config.target = compute_target
    run_config.node_count = 1
    config = ScriptRunConfig(source_directory=args.script_dir, script=args.script_name, run_config=run_config)
    run = experiment.submit(config)
    print(run.get_details()["runId"])
    while True:
        line = sys.stdin.readline().rstrip()
        if line == 'update_status':
            print('status:' + run.get_status())
        elif line == 'tracking_url':
            print('tracking_url:' + run.get_portal_url())
        elif line == 'stop':
            run.cancel()
            exit(0)
        elif line == 'receive':
            print('receive:' + json.dumps(run.get_metrics()))
        elif line:
            items = line.split(':')
            if items[0] == 'command':
                run.log('nni_manager', line[8:])
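The loop at the end of amlUtil.py is a small line-based protocol over stdin/stdout: each request line either produces one prefixed reply line or triggers a side effect on the run. The sketch below isolates that dispatch so it can be exercised without Azure. `FakeRun` and `handle_line` are hypothetical names introduced here; `FakeRun` only imitates the five run methods the script calls and is not an azureml class.

```python
import json


class FakeRun:
    """Hypothetical stand-in for the submitted run; not an azureml class."""

    def __init__(self):
        self.cancelled = False
        self.logged = []

    def get_status(self):
        return "Running"

    def get_portal_url(self):
        return "https://example.invalid/runs/1"

    def get_metrics(self):
        return {"trial_runner": ["hello"]}

    def cancel(self):
        self.cancelled = True

    def log(self, name, value):
        self.logged.append((name, value))


def handle_line(run, line):
    """Handle one request line; return the reply text to print, or None."""
    line = line.rstrip()
    if line == "update_status":
        return "status:" + run.get_status()
    if line == "tracking_url":
        return "tracking_url:" + run.get_portal_url()
    if line == "receive":
        return "receive:" + json.dumps(run.get_metrics())
    if line == "stop":
        run.cancel()  # the script in the commit also exits the process here
        return None
    if line.startswith("command:"):
        # Everything after the prefix is logged under the 'nni_manager' metric.
        run.log("nni_manager", line[len("command:"):])
    return None
```

Separating the dispatch from the `while True` loop is a design choice of this sketch only; the committed script keeps both inline.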
src/nni_manager/main.ts
View file @
93f96d4f
...
@@ -65,6 +65,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
         Container.bind(TrainingService)
             .to(DLTSTrainingService)
             .scope(Scope.Singleton);
+    } else if (platformMode === 'aml') {
+        Container.bind(TrainingService)
+            .to(RouterTrainingService)
+            .scope(Scope.Singleton);
     } else {
         throw new Error(`Error: unsupported mode: ${platformMode}`);
     }
...
@@ -93,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
 function usage(): void {
     console.info('usage: node main.js --port <port> --mode \
-    <local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
+    <local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
 }
 const strPort: string = parseArg(['--port', '-p']);
...
@@ -113,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
 const port: number = parseInt(strPort, 10);
 const mode: string = parseArg(['--mode', '-m']);
-if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'].includes(mode)) {
+if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) {
     console.log(`FATAL: unknown mode: ${mode}`);
     usage();
     process.exit(1);
...
src/nni_manager/package.json
View file @
93f96d4f
...
@@ -19,6 +19,7 @@
     "ignore": "^5.1.4",
     "js-base64": "^2.4.9",
     "kubernetes-client": "^6.5.0",
+    "python-shell": "^2.0.1",
     "rx": "^4.1.0",
     "sqlite3": "^4.0.2",
     "ssh2": "^0.6.1",
...
src/nni_manager/rest_server/restValidationSchemas.ts
View file @
93f96d4f
...
@@ -39,6 +39,8 @@ export namespace ValidationSchemas {
             nniManagerNFSMountPath: joi.string().min(1),
             containerNFSMountPath: joi.string().min(1),
             paiConfigPath: joi.string(),
+            computeTarget: joi.string(),
+            nodeCount: joi.number(),
             paiStorageConfigName: joi.string().min(1),
             nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
             portList: joi.array().items(joi.object({
...
@@ -150,6 +152,11 @@ export namespace ValidationSchemas {
             email: joi.string().min(1),
             password: joi.string().min(1)
         }),
+        aml_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
+            subscriptionId: joi.string().min(1),
+            resourceGroup: joi.string().min(1),
+            workspaceName: joi.string().min(1)
+        }),
         nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
             nniManagerIp: joi.string().min(1)
         })
...
src/nni_manager/training_service/common/trialConfigMetadataKey.ts
View file @
93f96d4f
...
@@ -19,6 +19,7 @@ export enum TrialConfigMetadataKey {
     NNI_MANAGER_IP = 'nni_manager_ip',
     FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config',
     DLTS_CLUSTER_CONFIG = 'dlts_config',
+    AML_CLUSTER_CONFIG = 'aml_config',
     VERSION_CHECK = 'version_check',
     LOG_COLLECTION = 'log_collection'
 }
src/nni_manager/training_service/reusable/aml/amlClient.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { Deferred } from 'ts-deferred';
import { PythonShell } from 'python-shell';

export class AMLClient {
    public subscriptionId: string;
    public resourceGroup: string;
    public workspaceName: string;
    public experimentId: string;
    public image: string;
    public scriptName: string;
    public pythonShellClient: undefined | PythonShell;
    public codeDir: string;
    public computeTarget: string;

    constructor(
        subscriptionId: string,
        resourceGroup: string,
        workspaceName: string,
        experimentId: string,
        computeTarget: string,
        image: string,
        scriptName: string,
        codeDir: string,
    ) {
        this.subscriptionId = subscriptionId;
        this.resourceGroup = resourceGroup;
        this.workspaceName = workspaceName;
        this.experimentId = experimentId;
        this.image = image;
        this.scriptName = scriptName;
        this.codeDir = codeDir;
        this.computeTarget = computeTarget;
    }

    public submit(): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        this.pythonShellClient = new PythonShell('amlUtil.py', {
            scriptPath: './config/aml',
            pythonOptions: ['-u'], // get print results in real-time
            args: [
                '--subscription_id', this.subscriptionId,
                '--resource_group', this.resourceGroup,
                '--workspace_name', this.workspaceName,
                '--compute_target', this.computeTarget,
                '--docker_image', this.image,
                '--experiment_name', `nni_exp_${this.experimentId}`,
                '--script_dir', this.codeDir,
                '--script_name', this.scriptName
            ]
        });
        this.pythonShellClient.on('message', function (envId: any) {
            // received a message sent from the Python script (a simple "print" statement)
            deferred.resolve(envId);
        });
        return deferred.promise;
    }

    public stop(): void {
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('stop');
    }

    public getTrackingUrl(): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('tracking_url');
        let trackingUrl = '';
        this.pythonShellClient.on('message', function (status: any) {
            const items = status.split(':');
            if (items[0] === 'tracking_url') {
                trackingUrl = items.splice(1, items.length).join('')
            }
            deferred.resolve(trackingUrl);
        });
        return deferred.promise;
    }

    public updateStatus(oldStatus: string): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        let newStatus = oldStatus;
        this.pythonShellClient.send('update_status');
        this.pythonShellClient.on('message', function (status: any) {
            const items = status.split(':');
            if (items[0] === 'status') {
                newStatus = items.splice(1, items.length).join('')
            }
            deferred.resolve(newStatus);
        });
        return deferred.promise;
    }

    public sendCommand(message: string): void {
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send(`command:${message}`);
    }

    public receiveCommand(): Promise<any> {
        const deferred: Deferred<any> = new Deferred<any>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('receive');
        this.pythonShellClient.on('message', function (command: any) {
            const items = command.split(':')
            if (items[0] === 'receive') {
                deferred.resolve(JSON.parse(command.slice(8)))
            }
        });
        return deferred.promise;
    }
}
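One detail in `getTrackingUrl` and `updateStatus` is worth illustrating: the reply is split on every `:` and the remaining pieces are joined with an empty string, so any colon inside the payload (for example the one in `https://`) is dropped. A minimal sketch of a prefix-aware alternative is shown below, written in Python for brevity; `parse_reply` is a hypothetical helper and not code from the commit.

```python
def parse_reply(line, expected_prefix):
    """Split 'prefix:payload' at the first colon only.

    Returns the payload unchanged, or None when the prefix does not match.
    """
    prefix, sep, payload = line.partition(":")
    if sep and prefix == expected_prefix:
        return payload
    return None


def naive_payload(line):
    """Mimics split(':') followed by join('') on everything after the prefix."""
    return "".join(line.split(":")[1:])
```

Splitting only at the first separator keeps the payload byte-for-byte intact, which matters for URLs and for JSON bodies that contain colons.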
src/nni_manager/training_service/reusable/aml/amlConfig.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { TrialConfig } from '../../common/trialConfig';
import { EnvironmentInformation } from '../environment';
import { AMLClient } from '../aml/amlClient';

export class AMLClusterConfig {
    public readonly subscriptionId: string;
    public readonly resourceGroup: string;
    public readonly workspaceName: string;

    constructor(subscriptionId: string, resourceGroup: string, workspaceName: string) {
        this.subscriptionId = subscriptionId;
        this.resourceGroup = resourceGroup;
        this.workspaceName = workspaceName;
    }
}

export class AMLTrialConfig extends TrialConfig {
    public readonly image: string;
    public readonly command: string;
    public readonly codeDir: string;
    public readonly computeTarget: string;

    constructor(codeDir: string, command: string, image: string, computeTarget: string) {
        super("", codeDir, 0);
        this.codeDir = codeDir;
        this.command = command;
        this.image = image;
        this.computeTarget = computeTarget;
    }
}

export class AMLEnvironmentInformation extends EnvironmentInformation {
    public amlClient?: AMLClient;
}
src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { EventEmitter } from 'events';
import { delay } from "../../../common/utils";
import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment";

class AMLRunnerConnection extends RunnerConnection {
}

export class AMLCommandChannel extends CommandChannel {
    private stopping: boolean = false;
    private currentMessageIndex: number = -1;
    private sendQueues: [EnvironmentInformation, string][] = [];
    private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

    public constructor(commandEmitter: EventEmitter) {
        super(commandEmitter);
    }

    public get channelName(): Channel {
        return "aml";
    }

    public async config(_key: string, _value: any): Promise<void> {
        // do nothing
    }

    public async start(): Promise<void> {
        // do nothing
    }

    public async stop(): Promise<void> {
        this.stopping = true;
    }

    public async run(): Promise<void> {
        // start command loops
        await Promise.all([
            this.receiveLoop(),
            this.sendLoop()
        ]);
    }

    protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
        this.sendQueues.push([environment, message]);
    }

    protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection {
        return new AMLRunnerConnection(environment);
    }

    private async sendLoop(): Promise<void> {
        const intervalSeconds = 0.5;
        while (!this.stopping) {
            const start = new Date();

            if (this.sendQueues.length > 0) {
                while (this.sendQueues.length > 0) {
                    const item = this.sendQueues.shift();
                    if (item === undefined) {
                        break;
                    }
                    const environment = item[0];
                    const message = item[1];
                    const amlClient = (environment as AMLEnvironmentInformation).amlClient;
                    if (!amlClient) {
                        throw new Error('aml client not initialized!');
                    }
                    amlClient.sendCommand(message);
                }
            }

            const end = new Date();
            const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
            if (delayMs > 0) {
                await delay(delayMs);
            }
        }
    }

    private async receiveLoop(): Promise<void> {
        const intervalSeconds = 2;

        while (!this.stopping) {
            const start = new Date();
            const runnerConnections = [...this.runnerConnections.values()] as AMLRunnerConnection[];
            for (const runnerConnection of runnerConnections) {
                // to loop all commands
                const amlClient = (runnerConnection.environment as AMLEnvironmentInformation).amlClient;
                if (!amlClient) {
                    throw new Error('AML client not initialized!');
                }
                const command = await amlClient.receiveCommand();
                if (command && Object.prototype.hasOwnProperty.call(command, "trial_runner")) {
                    const messages = command['trial_runner'];
                    if (messages) {
                        if (messages instanceof Object && this.currentMessageIndex < messages.length - 1) {
                            for (let index = this.currentMessageIndex + 1; index < messages.length; index++) {
                                this.handleCommand(runnerConnection.environment, messages[index]);
                            }
                            this.currentMessageIndex = messages.length - 1;
                        } else if (this.currentMessageIndex === -1) {
                            this.handleCommand(runnerConnection.environment, messages);
                            this.currentMessageIndex += 1;
                        }
                    }
                }
            }

            const end = new Date();
            const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
            if (delayMs > 0) {
                await delay(delayMs);
            }
        }
    }
}
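The receive loop polls the full `trial_runner` metric on every tick and uses `currentMessageIndex` to skip entries it has already handled. The bookkeeping can be sketched in isolation as follows, in Python for brevity. `take_new_messages` is a hypothetical helper that mirrors the two branches of the TypeScript code and leaves out its truthiness check on `messages`.

```python
def take_new_messages(messages, current_index):
    """Return (messages still to handle, updated index).

    A list is treated as the full history so far; a single non-list value
    is only handled if nothing has been handled yet (index == -1).
    """
    if isinstance(messages, list) and current_index < len(messages) - 1:
        return messages[current_index + 1:], len(messages) - 1
    if current_index == -1:
        return [messages], 0
    return [], current_index
```

Note that the committed class keeps a single index for the whole channel rather than one per runner connection, so the sketch likewise tracks only one history.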
src/nni_manager/training_service/reusable/channels/fileCommandChannel.ts
View file @
93f96d4f
...
@@ -6,7 +6,7 @@
 import * as component from "../../../common/component";
 import { delay } from "../../../common/utils";
 import { CommandChannel, RunnerConnection } from "../commandChannel";
-import { EnvironmentInformation, Channel } from "../environment";
+import { Channel, EnvironmentInformation } from "../environment";
 import { StorageService } from "../storageService";
 class FileHandler {
...
@@ -38,15 +38,21 @@ export class FileCommandChannel extends CommandChannel {
     }
     public async start(): Promise<void> {
-        // start command loops
-        this.receiveLoop();
-        this.sendLoop();
+        // do nothing
     }
     public async stop(): Promise<void> {
         this.stopping = true;
     }
+    public async run(): Promise<void> {
+        // start command loops
+        await Promise.all([
+            this.receiveLoop(),
+            this.sendLoop()
+        ]);
+    }
     protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
         this.sendQueues.push([environment, message]);
     }
...
src/nni_manager/training_service/reusable/channels/webCommandChannel.ts
View file @
93f96d4f
...
@@ -66,6 +66,10 @@ export class WebCommandChannel extends CommandChannel {
         }
     }
+    public async run(): Promise<void> {
+        // do nothing
+    }
     protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
         if (this.webSocketServer === undefined) {
             throw new Error(`WebCommandChannel: uninitialized!`)
...
src/nni_manager/training_service/reusable/commandChannel.ts
View file @
93f96d4f
...
@@ -59,6 +59,9 @@ export abstract class CommandChannel {
     public abstract start(): Promise<void>;
     public abstract stop(): Promise<void>;
+    // Pull-based command channels need loop to check messages, the loop should be started with await here.
+    public abstract run(): Promise<void>;
     protected abstract sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void>;
     protected abstract createRunnerConnection(environment: EnvironmentInformation): RunnerConnection;
...
src/nni_manager/training_service/reusable/environment.ts
View file @
93f96d4f
...
@@ -14,7 +14,6 @@ import { CommandChannel } from "./commandChannel";
 export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED';
 export type Channel = "web" | "file" | "aml" | "ut";
 export class EnvironmentInformation {
     private log: Logger;
...
@@ -65,6 +64,7 @@ export class EnvironmentInformation {
         }
     }
 }
 export abstract class EnvironmentService {
     public abstract get hasStorageService(): boolean;
...
src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
TrialConfigMetadataKey
}
from
'
../../common/trialConfigMetadataKey
'
;
import
{
AMLClusterConfig
,
AMLTrialConfig
}
from
'
../aml/amlConfig
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
AMLEnvironmentInformation
}
from
'
../aml/amlConfig
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
import
{
NNIManagerIpConfig
,
}
from
'
../../../common/trainingService
'
;
import
{
validateCodeDir
}
from
'
../../common/util
'
;
import
{
getExperimentRootDir
}
from
'
../../../common/utils
'
;
import
{
AMLCommandChannel
}
from
'
../channels/amlCommandChannel
'
;
import
{
CommandChannel
}
from
"
../commandChannel
"
;
import
{
EventEmitter
}
from
"
events
"
;
/**
* Collector PAI jobs info from PAI cluster, and update pai job status locally
*/
@
component
.
Singleton
export
class
AMLEnvironmentService
extends
EnvironmentService
{
private
readonly
log
:
Logger
=
getLogger
();
public
amlClusterConfig
:
AMLClusterConfig
|
undefined
;
public
amlTrialConfig
:
AMLTrialConfig
|
undefined
;
private
amlJobConfig
:
any
;
private
stopping
:
boolean
=
false
;
private
versionCheck
:
boolean
=
true
;
private
isMultiPhase
:
boolean
=
false
;
private
nniVersion
?:
string
;
private
experimentId
:
string
;
private
nniManagerIpConfig
?:
NNIManagerIpConfig
;
private
experimentRootDir
:
string
;
constructor
()
{
super
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentRootDir
=
getExperimentRootDir
();
}
public
get
hasStorageService
():
boolean
{
return
false
;
}
public
getCommandChannel
(
commandEmitter
:
EventEmitter
):
CommandChannel
{
return
new
AMLCommandChannel
(
commandEmitter
);
}
public
createEnviornmentInfomation
(
envId
:
string
,
envName
:
string
):
EnvironmentInformation
{
return
new
AMLEnvironmentInformation
(
envId
,
envName
);
}
public
async
config
(
key
:
string
,
value
:
string
):
Promise
<
void
>
{
switch
(
key
)
{
case
TrialConfigMetadataKey
.
AML_CLUSTER_CONFIG
:
this
.
amlClusterConfig
=
<
AMLClusterConfig
>
JSON
.
parse
(
value
);
break
;
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
if
(
this
.
amlClusterConfig
===
undefined
)
{
this
.
log
.
error
(
'
aml cluster config is not initialized
'
);
break
;
}
this
.
amlTrialConfig
=
<
AMLTrialConfig
>
JSON
.
parse
(
value
);
// Validate to make sure codeDir doesn't have too many files
await
validateCodeDir
(
this
.
amlTrialConfig
.
codeDir
);
break
;
}
default
:
this
.
log
.
debug
(
`AML not proccessed metadata key: '
${
key
}
', value: '
${
value
}
'`
);
}
}

    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        environments.forEach(async (environment) => {
            const amlClient = (environment as AMLEnvironmentInformation).amlClient;
            if (!amlClient) {
                throw new Error('AML client not initialized!');
            }
            const status = await amlClient.updateStatus(environment.status);
            switch (status.toUpperCase()) {
                case 'WAITING':
                case 'RUNNING':
                case 'QUEUED':
                    // RUNNING status is set by runner, and ignore waiting status
                    break;
                case 'COMPLETED':
                case 'SUCCEEDED':
                    environment.setFinalStatus('SUCCEEDED');
                    break;
                case 'FAILED':
                    environment.setFinalStatus('FAILED');
                    break;
                case 'STOPPED':
                case 'STOPPING':
                    environment.setFinalStatus('USER_CANCELED');
                    break;
                default:
                    environment.setFinalStatus('UNKNOWN');
            }
        });
    }

    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (this.amlClusterConfig === undefined) {
            throw new Error('AML Cluster config is not initialized');
        }
        if (this.amlTrialConfig === undefined) {
            throw new Error('AML trial config is not initialized');
        }
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
        environment.command = `import os\nos.system('${amlEnvironment.command}')`;
        await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
        const amlClient = new AMLClient(
            this.amlClusterConfig.subscriptionId,
            this.amlClusterConfig.resourceGroup,
            this.amlClusterConfig.workspaceName,
            this.experimentId,
            this.amlTrialConfig.computeTarget,
            this.amlTrialConfig.image,
            'nni_script.py',
            environmentLocalTempFolder
        );
        amlEnvironment.id = await amlClient.submit();
        amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
        amlEnvironment.amlClient = amlClient;
    }

    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        const amlClient = amlEnvironment.amlClient;
        if (!amlClient) {
            throw new Error('AML client not initialized!');
        }
        amlClient.stop();
    }
}
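
Review note: the `switch` in `refreshEnvironmentsStatus` amounts to a mapping from AML run states to NNI final states. The sketch below restates that mapping as a pure function and, as a possible variation rather than what this commit does, awaits every refresh with `Promise.all` so that a rejected status call reaches the caller instead of being dropped inside `forEach(async ...)`. The names `Env`, `toFinalStatus`, and `refreshAll` are hypothetical stand-ins; only the status strings are taken from the diff.

```typescript
// Hypothetical stand-in for the fields of AMLEnvironmentInformation used here.
interface Env {
    status: string;
    finalStatus?: string;
    // Assumed shape: resolves to the status string reported by AML.
    fetchStatus: (oldStatus: string) => Promise<string>;
}

// Returns the final status for a reported AML state, or undefined
// while the run is still in progress (WAITING, RUNNING, QUEUED).
function toFinalStatus(amlStatus: string): string | undefined {
    switch (amlStatus.toUpperCase()) {
        case 'WAITING':
        case 'RUNNING':
        case 'QUEUED':
            return undefined;
        case 'COMPLETED':
        case 'SUCCEEDED':
            return 'SUCCEEDED';
        case 'FAILED':
            return 'FAILED';
        case 'STOPPED':
        case 'STOPPING':
            return 'USER_CANCELED';
        default:
            return 'UNKNOWN';
    }
}

// Variation on the committed code: Promise.all waits for every refresh,
// so a rejected fetchStatus call rejects refreshAll itself.
async function refreshAll(environments: Env[]): Promise<void> {
    await Promise.all(environments.map(async (env) => {
        const finalStatus = toFinalStatus(await env.fetchStatus(env.status));
        if (finalStatus !== undefined) {
            env.finalStatus = finalStatus;
        }
    }));
}
```

Keeping the mapping separate from the polling loop would also make it straightforward to unit test without an AML workspace.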
src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
@@ -167,8 +167,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
         }
         // Step 1. Prepare PAI job configuration
-        environment.runnerWorkingFolder = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/envs/${environment.id}`;
+        const environmentRoot = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}`;
-        environment.command = `cd ${environment.runnerWorkingFolder} && ${environment.command}`
+        environment.runnerWorkingFolder = `${environmentRoot}/envs/${environment.id}`;
+        environment.command = `cd ${environmentRoot} && ${environment.command}`
         environment.trackingUrl = `${this.protocol}://${this.paiClusterConfig.host}/job-detail.html?username=${this.paiClusterConfig.userName}&jobName=${environment.jobId}`
         // Step 2. Generate Job Configuration in yaml format
...
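
Review note: the added lines of the hunk above make the runner `cd` into the experiment root instead of the per-environment folder, while `runnerWorkingFolder` still points at `envs/<id>` beneath it. A minimal sketch of that path construction, assuming a hypothetical helper `buildPaiPaths` and made-up argument values that do not appear in the commit:

```typescript
// Hypothetical helper mirroring the added lines of the hunk above.
interface PaiPaths {
    environmentRoot: string;
    runnerWorkingFolder: string;
    command: string;
}

function buildPaiPaths(mountPath: string, experimentId: string, envId: string, command: string): PaiPaths {
    // Shared root for the whole experiment on the mounted storage.
    const environmentRoot = `${mountPath}/${experimentId}`;
    return {
        environmentRoot,
        // Per-environment folder lives under the shared root.
        runnerWorkingFolder: `${environmentRoot}/envs/${envId}`,
        // The command now starts from the shared root, not the env folder.
        command: `cd ${environmentRoot} && ${command}`,
    };
}
```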
src/nni_manager/training_service/reusable/routerTrainingService.ts
@@ -13,6 +13,7 @@ import { PAIClusterConfig } from '../pai/paiConfig';
 import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
 import { EnvironmentService } from './environment';
 import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
+import { AMLEnvironmentService } from './environments/amlEnvironmentService';
 import { MountedStorageService } from './storages/mountedStorageService';
 import { StorageService } from './storageService';
 import { TrialDispatcher } from './trialDispatcher';
@@ -120,6 +121,25 @@ class RouterTrainingService implements TrainingService {
                 }
                 await this.internalTrainingService.setClusterMetadata(key, value);
+                this.metaDataCache.clear();
+            } else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
+                this.internalTrainingService = component.get(TrialDispatcher);
+                Container.bind(EnvironmentService)
+                    .to(AMLEnvironmentService)
+                    .scope(Scope.Singleton);
+                for (const [key, value] of this.metaDataCache) {
+                    if (this.internalTrainingService === undefined) {
+                        throw new Error("TrainingService is not assigned!");
+                    }
+                    await this.internalTrainingService.setClusterMetadata(key, value);
+                }
+                if (this.internalTrainingService === undefined) {
+                    throw new Error("TrainingService is not assigned!");
+                }
+                await this.internalTrainingService.setClusterMetadata(key, value);
                 this.metaDataCache.clear();
             } else {
                 this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
...
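
Review note: the new AML branch appears to follow a cache-then-replay pattern: metadata that arrives before the training service is known is cached, and once a cluster config key selects the service, the cached entries are replayed, the triggering key is forwarded, and the cache is cleared. The sketch below illustrates that flow in isolation. `Sink`, `MetadataRouter`, and the key names in the usage are hypothetical, and forwarding directly once a target exists is an assumption about code outside the visible hunk.

```typescript
// Hypothetical minimal interface for whatever receives the metadata.
interface Sink {
    setClusterMetadata(key: string, value: string): Promise<void>;
}

class MetadataRouter {
    private readonly cache: Map<string, string> = new Map<string, string>();
    private readonly selectors: Map<string, () => Sink>;
    private target: Sink | undefined;

    // selectors maps a cluster-config key to a factory for its service.
    constructor(selectors: Map<string, () => Sink>) {
        this.selectors = selectors;
    }

    public async setClusterMetadata(key: string, value: string): Promise<void> {
        if (this.target !== undefined) {
            // Assumed behavior: once a service is chosen, forward directly.
            await this.target.setClusterMetadata(key, value);
            return;
        }
        const factory = this.selectors.get(key);
        if (factory === undefined) {
            // Service not determined yet: remember the entry for later.
            this.cache.set(key, value);
            return;
        }
        const target = factory();
        this.target = target;
        // Replay cached entries in insertion order, then the triggering key.
        for (const [cachedKey, cachedValue] of Array.from(this.cache.entries())) {
            await target.setClusterMetadata(cachedKey, cachedValue);
        }
        await target.setClusterMetadata(key, value);
        this.cache.clear();
    }
}
```

A table of selectors like this would be one way to avoid repeating the replay loop for each new training service branch; whether that fits the project's dependency-injection setup is a separate question.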