Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
93f96d4f
Unverified
Commit
93f96d4f
authored
Jul 01, 2020
by
SparkSnail
Committed by
GitHub
Jul 01, 2020
Browse files
Support aml (#2615)
parent
f5caa193
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
660 additions
and
9 deletions
+660
-9
docs/en_US/TrainingService/AMLMode.md
docs/en_US/TrainingService/AMLMode.md
+66
-0
docs/en_US/training_services.rst
docs/en_US/training_services.rst
+1
-0
docs/img/aml_account.png
docs/img/aml_account.png
+0
-0
examples/trials/mnist-pytorch/config_aml.yml
examples/trials/mnist-pytorch/config_aml.yml
+25
-0
examples/trials/mnist-tfv1/config_aml.yml
examples/trials/mnist-tfv1/config_aml.yml
+25
-0
src/nni_manager/config/aml/amlUtil.py
src/nni_manager/config/aml/amlUtil.py
+56
-0
src/nni_manager/main.ts
src/nni_manager/main.ts
+6
-2
src/nni_manager/package.json
src/nni_manager/package.json
+1
-0
src/nni_manager/rest_server/restValidationSchemas.ts
src/nni_manager/rest_server/restValidationSchemas.ts
+7
-0
src/nni_manager/training_service/common/trialConfigMetadataKey.ts
...manager/training_service/common/trialConfigMetadataKey.ts
+1
-0
src/nni_manager/training_service/reusable/aml/amlClient.ts
src/nni_manager/training_service/reusable/aml/amlClient.ts
+125
-0
src/nni_manager/training_service/reusable/aml/amlConfig.ts
src/nni_manager/training_service/reusable/aml/amlConfig.ts
+39
-0
src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts
...r/training_service/reusable/channels/amlCommandChannel.ts
+120
-0
src/nni_manager/training_service/reusable/channels/fileCommandChannel.ts
.../training_service/reusable/channels/fileCommandChannel.ts
+10
-4
src/nni_manager/training_service/reusable/channels/webCommandChannel.ts
...r/training_service/reusable/channels/webCommandChannel.ts
+4
-0
src/nni_manager/training_service/reusable/commandChannel.ts
src/nni_manager/training_service/reusable/commandChannel.ts
+3
-0
src/nni_manager/training_service/reusable/environment.ts
src/nni_manager/training_service/reusable/environment.ts
+1
-1
src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
...ng_service/reusable/environments/amlEnvironmentService.ts
+147
-0
src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
...ervice/reusable/environments/openPaiEnvironmentService.ts
+3
-2
src/nni_manager/training_service/reusable/routerTrainingService.ts
...anager/training_service/reusable/routerTrainingService.ts
+20
-0
No files found.
docs/en_US/TrainingService/AMLMode.md
0 → 100644
View file @
93f96d4f
**Run an Experiment on Azure Machine Learning**
===
NNI supports running an experiment on
[
AML
](
https://azure.microsoft.com/en-us/services/machine-learning/
)
, called aml mode.
## Setup environment
Step 1. Install NNI, follow the install guide
[
here
](
../Tutorial/QuickStart.md
)
.
Step 2. Create AML account, follow the document
[
here
](
https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli
)
.
Step 3. Get your account information.

Step4. Install AML package environment.
```
python3 -m pip install azureml --user
python3 -m pip install azureml-sdk --user
```
## Run an experiment
Use
`examples/trials/mnist-tfv1`
as an example. The NNI config YAML file's content is like:
```
yaml
authorName
:
default
experimentName
:
example_mnist
trialConcurrency
:
1
maxExecDuration
:
1h
maxTrialNum
:
10
trainingServicePlatform
:
aml
searchSpacePath
:
search_space.json
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
TPE
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
trial
:
command
:
python3 mnist.py
codeDir
:
.
computeTarget
:
${replace_to_your_computeTarget}
image
:
msranni/nni
amlConfig
:
subscriptionId
:
${replace_to_your_subscriptionId}
resourceGroup
:
${replace_to_your_resourceGroup}
workspaceName
:
${replace_to_your_workspaceName}
```
Note: You should set
`trainingServicePlatform: aml`
in NNI config YAML file if you want to start experiment in aml mode.
Compared with
[
LocalMode
](
LocalMode.md
)
trial configuration in aml mode have these additional keys:
*
computeTarget
*
required key. The computer cluster name you want to use in your AML workspace.
*
image
*
required key. The docker image name used in job.
amlConfig:
*
subscriptionId
*
the subscriptionId of your account
*
resourceGroup
*
the resourceGroup of your account
*
workspaceName
*
the workspaceName of your account
\ No newline at end of file
docs/en_US/training_services.rst
View file @
93f96d4f
...
...
@@ -10,3 +10,4 @@ Introduction to NNI Training Services
Kubeflow<./TrainingService/KubeflowMode>
FrameworkController<./TrainingService/FrameworkControllerMode>
DLTS<./TrainingService/DLTSMode>
AML<./TrainingService/AMLMode>
docs/img/aml_account.png
0 → 100644
View file @
93f96d4f
24.7 KB
examples/trials/mnist-pytorch/config_aml.yml
0 → 100644
View file @
93f96d4f
authorName
:
default
experimentName
:
example_mnist_pytorch
trialConcurrency
:
1
maxExecDuration
:
1h
maxTrialNum
:
10
trainingServicePlatform
:
aml
searchSpacePath
:
search_space.json
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
TPE
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
trial
:
command
:
python3 mnist.py
codeDir
:
.
computeTarget
:
${replace_to_your_computeTarget}
image
:
msranni/nni
amlConfig
:
subscriptionId
:
${replace_to_your_subscriptionId}
resourceGroup
:
${replace_to_your_resourceGroup}
workspaceName
:
${replace_to_your_workspaceName}
examples/trials/mnist-tfv1/config_aml.yml
0 → 100644
View file @
93f96d4f
authorName
:
default
experimentName
:
example_mnist
trialConcurrency
:
1
maxExecDuration
:
1h
maxTrialNum
:
10
trainingServicePlatform
:
aml
searchSpacePath
:
search_space.json
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
TPE
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
trial
:
command
:
python3 mnist.py
codeDir
:
.
computeTarget
:
${replace_to_your_computeTarget}
image
:
msranni/nni
amlConfig
:
subscriptionId
:
${replace_to_your_subscriptionId}
resourceGroup
:
${replace_to_your_resourceGroup}
workspaceName
:
${replace_to_your_workspaceName}
src/nni_manager/config/aml/amlUtil.py
0 → 100644
View file @
93f96d4f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
sys
import
time
import
json
from
argparse
import
ArgumentParser
from
azureml.core
import
Experiment
,
RunConfiguration
,
ScriptRunConfig
from
azureml.core.compute
import
ComputeTarget
from
azureml.core.run
import
RUNNING_STATES
,
RunStatus
,
Run
from
azureml.core
import
Workspace
from
azureml.core.conda_dependencies
import
CondaDependencies
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
()
parser
.
add_argument
(
'--subscription_id'
,
help
=
'the subscription id of aml'
)
parser
.
add_argument
(
'--resource_group'
,
help
=
'the resource group of aml'
)
parser
.
add_argument
(
'--workspace_name'
,
help
=
'the workspace name of aml'
)
parser
.
add_argument
(
'--compute_target'
,
help
=
'the compute cluster name of aml'
)
parser
.
add_argument
(
'--docker_image'
,
help
=
'the docker image of job'
)
parser
.
add_argument
(
'--experiment_name'
,
help
=
'the experiment name'
)
parser
.
add_argument
(
'--script_dir'
,
help
=
'script directory'
)
parser
.
add_argument
(
'--script_name'
,
help
=
'script name'
)
args
=
parser
.
parse_args
()
ws
=
Workspace
(
args
.
subscription_id
,
args
.
resource_group
,
args
.
workspace_name
)
compute_target
=
ComputeTarget
(
workspace
=
ws
,
name
=
args
.
compute_target
)
experiment
=
Experiment
(
ws
,
args
.
experiment_name
)
run_config
=
RunConfiguration
()
dependencies
=
CondaDependencies
()
dependencies
.
add_pip_package
(
"azureml-sdk"
)
dependencies
.
add_pip_package
(
"azureml"
)
run_config
.
environment
.
python
.
conda_dependencies
=
dependencies
run_config
.
environment
.
docker
.
enabled
=
True
run_config
.
environment
.
docker
.
base_image
=
args
.
docker_image
run_config
.
target
=
compute_target
run_config
.
node_count
=
1
config
=
ScriptRunConfig
(
source_directory
=
args
.
script_dir
,
script
=
args
.
script_name
,
run_config
=
run_config
)
run
=
experiment
.
submit
(
config
)
print
(
run
.
get_details
()[
"runId"
])
while
True
:
line
=
sys
.
stdin
.
readline
().
rstrip
()
if
line
==
'update_status'
:
print
(
'status:'
+
run
.
get_status
())
elif
line
==
'tracking_url'
:
print
(
'tracking_url:'
+
run
.
get_portal_url
())
elif
line
==
'stop'
:
run
.
cancel
()
exit
(
0
)
elif
line
==
'receive'
:
print
(
'receive:'
+
json
.
dumps
(
run
.
get_metrics
()))
elif
line
:
items
=
line
.
split
(
':'
)
if
items
[
0
]
==
'command'
:
run
.
log
(
'nni_manager'
,
line
[
8
:])
src/nni_manager/main.ts
View file @
93f96d4f
...
...
@@ -65,6 +65,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container
.
bind
(
TrainingService
)
.
to
(
DLTSTrainingService
)
.
scope
(
Scope
.
Singleton
);
}
else
if
(
platformMode
===
'
aml
'
)
{
Container
.
bind
(
TrainingService
)
.
to
(
RouterTrainingService
)
.
scope
(
Scope
.
Singleton
);
}
else
{
throw
new
Error
(
`Error: unsupported mode:
${
platformMode
}
`
);
}
...
...
@@ -93,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
function
usage
():
void
{
console
.
info
(
'
usage: node main.js --port <port> --mode
\
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>
'
);
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn
/aml
> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>
'
);
}
const
strPort
:
string
=
parseArg
([
'
--port
'
,
'
-p
'
]);
...
...
@@ -113,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const
port
:
number
=
parseInt
(
strPort
,
10
);
const
mode
:
string
=
parseArg
([
'
--mode
'
,
'
-m
'
]);
if
(
!
[
'
local
'
,
'
remote
'
,
'
pai
'
,
'
kubeflow
'
,
'
frameworkcontroller
'
,
'
paiYarn
'
,
'
dlts
'
].
includes
(
mode
))
{
if
(
!
[
'
local
'
,
'
remote
'
,
'
pai
'
,
'
kubeflow
'
,
'
frameworkcontroller
'
,
'
paiYarn
'
,
'
dlts
'
,
'
aml
'
].
includes
(
mode
))
{
console
.
log
(
`FATAL: unknown mode:
${
mode
}
`
);
usage
();
process
.
exit
(
1
);
...
...
src/nni_manager/package.json
View file @
93f96d4f
...
...
@@ -19,6 +19,7 @@
"ignore"
:
"^5.1.4"
,
"js-base64"
:
"^2.4.9"
,
"kubernetes-client"
:
"^6.5.0"
,
"python-shell"
:
"^2.0.1"
,
"rx"
:
"^4.1.0"
,
"sqlite3"
:
"^4.0.2"
,
"ssh2"
:
"^0.6.1"
,
...
...
src/nni_manager/rest_server/restValidationSchemas.ts
View file @
93f96d4f
...
...
@@ -39,6 +39,8 @@ export namespace ValidationSchemas {
nniManagerNFSMountPath
:
joi
.
string
().
min
(
1
),
containerNFSMountPath
:
joi
.
string
().
min
(
1
),
paiConfigPath
:
joi
.
string
(),
computeTarget
:
joi
.
string
(),
nodeCount
:
joi
.
number
(),
paiStorageConfigName
:
joi
.
string
().
min
(
1
),
nasMode
:
joi
.
string
().
valid
(
'
classic_mode
'
,
'
enas_mode
'
,
'
oneshot_mode
'
,
'
darts_mode
'
),
portList
:
joi
.
array
().
items
(
joi
.
object
({
...
...
@@ -150,6 +152,11 @@ export namespace ValidationSchemas {
email
:
joi
.
string
().
min
(
1
),
password
:
joi
.
string
().
min
(
1
)
}),
aml_config
:
joi
.
object
({
// eslint-disable-line @typescript-eslint/camelcase
subscriptionId
:
joi
.
string
().
min
(
1
),
resourceGroup
:
joi
.
string
().
min
(
1
),
workspaceName
:
joi
.
string
().
min
(
1
)
}),
nni_manager_ip
:
joi
.
object
({
// eslint-disable-line @typescript-eslint/camelcase
nniManagerIp
:
joi
.
string
().
min
(
1
)
})
...
...
src/nni_manager/training_service/common/trialConfigMetadataKey.ts
View file @
93f96d4f
...
...
@@ -19,6 +19,7 @@ export enum TrialConfigMetadataKey {
NNI_MANAGER_IP
=
'
nni_manager_ip
'
,
FRAMEWORKCONTROLLER_CLUSTER_CONFIG
=
'
frameworkcontroller_config
'
,
DLTS_CLUSTER_CONFIG
=
'
dlts_config
'
,
AML_CLUSTER_CONFIG
=
'
aml_config
'
,
VERSION_CHECK
=
'
version_check
'
,
LOG_COLLECTION
=
'
log_collection
'
}
src/nni_manager/training_service/reusable/aml/amlClient.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
PythonShell
}
from
'
python-shell
'
;
export
class
AMLClient
{
public
subscriptionId
:
string
;
public
resourceGroup
:
string
;
public
workspaceName
:
string
;
public
experimentId
:
string
;
public
image
:
string
;
public
scriptName
:
string
;
public
pythonShellClient
:
undefined
|
PythonShell
;
public
codeDir
:
string
;
public
computeTarget
:
string
;
constructor
(
subscriptionId
:
string
,
resourceGroup
:
string
,
workspaceName
:
string
,
experimentId
:
string
,
computeTarget
:
string
,
image
:
string
,
scriptName
:
string
,
codeDir
:
string
,
)
{
this
.
subscriptionId
=
subscriptionId
;
this
.
resourceGroup
=
resourceGroup
;
this
.
workspaceName
=
workspaceName
;
this
.
experimentId
=
experimentId
;
this
.
image
=
image
;
this
.
scriptName
=
scriptName
;
this
.
codeDir
=
codeDir
;
this
.
computeTarget
=
computeTarget
;
}
public
submit
():
Promise
<
string
>
{
const
deferred
:
Deferred
<
string
>
=
new
Deferred
<
string
>
();
this
.
pythonShellClient
=
new
PythonShell
(
'
amlUtil.py
'
,
{
scriptPath
:
'
./config/aml
'
,
pythonOptions
:
[
'
-u
'
],
// get print results in real-time
args
:
[
'
--subscription_id
'
,
this
.
subscriptionId
,
'
--resource_group
'
,
this
.
resourceGroup
,
'
--workspace_name
'
,
this
.
workspaceName
,
'
--compute_target
'
,
this
.
computeTarget
,
'
--docker_image
'
,
this
.
image
,
'
--experiment_name
'
,
`nni_exp_
${
this
.
experimentId
}
`
,
'
--script_dir
'
,
this
.
codeDir
,
'
--script_name
'
,
this
.
scriptName
]
});
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
envId
:
any
)
{
// received a message sent from the Python script (a simple "print" statement)
deferred
.
resolve
(
envId
);
});
return
deferred
.
promise
;
}
public
stop
():
void
{
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
}
this
.
pythonShellClient
.
send
(
'
stop
'
);
}
public
getTrackingUrl
():
Promise
<
string
>
{
const
deferred
:
Deferred
<
string
>
=
new
Deferred
<
string
>
();
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
}
this
.
pythonShellClient
.
send
(
'
tracking_url
'
);
let
trackingUrl
=
''
;
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
const
items
=
status
.
split
(
'
:
'
);
if
(
items
[
0
]
===
'
tracking_url
'
)
{
trackingUrl
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
}
deferred
.
resolve
(
trackingUrl
);
});
return
deferred
.
promise
;
}
public
updateStatus
(
oldStatus
:
string
):
Promise
<
string
>
{
const
deferred
:
Deferred
<
string
>
=
new
Deferred
<
string
>
();
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
}
let
newStatus
=
oldStatus
;
this
.
pythonShellClient
.
send
(
'
update_status
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
const
items
=
status
.
split
(
'
:
'
);
if
(
items
[
0
]
===
'
status
'
)
{
newStatus
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
}
deferred
.
resolve
(
newStatus
);
});
return
deferred
.
promise
;
}
public
sendCommand
(
message
:
string
):
void
{
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
}
this
.
pythonShellClient
.
send
(
`command:
${
message
}
`
);
}
public
receiveCommand
():
Promise
<
any
>
{
const
deferred
:
Deferred
<
any
>
=
new
Deferred
<
any
>
();
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
}
this
.
pythonShellClient
.
send
(
'
receive
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
command
:
any
)
{
const
items
=
command
.
split
(
'
:
'
)
if
(
items
[
0
]
===
'
receive
'
)
{
deferred
.
resolve
(
JSON
.
parse
(
command
.
slice
(
8
)))
}
});
return
deferred
.
promise
;
}
}
src/nni_manager/training_service/reusable/aml/amlConfig.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
{
TrialConfig
}
from
'
../../common/trialConfig
'
;
import
{
EnvironmentInformation
}
from
'
../environment
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
export
class
AMLClusterConfig
{
public
readonly
subscriptionId
:
string
;
public
readonly
resourceGroup
:
string
;
public
readonly
workspaceName
:
string
;
constructor
(
subscriptionId
:
string
,
resourceGroup
:
string
,
workspaceName
:
string
)
{
this
.
subscriptionId
=
subscriptionId
;
this
.
resourceGroup
=
resourceGroup
;
this
.
workspaceName
=
workspaceName
;
}
}
export
class
AMLTrialConfig
extends
TrialConfig
{
public
readonly
image
:
string
;
public
readonly
command
:
string
;
public
readonly
codeDir
:
string
;
public
readonly
computeTarget
:
string
;
constructor
(
codeDir
:
string
,
command
:
string
,
image
:
string
,
computeTarget
:
string
)
{
super
(
""
,
codeDir
,
0
);
this
.
codeDir
=
codeDir
;
this
.
command
=
command
;
this
.
image
=
image
;
this
.
computeTarget
=
computeTarget
;
}
}
export
class
AMLEnvironmentInformation
extends
EnvironmentInformation
{
public
amlClient
?:
AMLClient
;
}
src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
{
EventEmitter
}
from
'
events
'
;
import
{
delay
}
from
"
../../../common/utils
"
;
import
{
AMLEnvironmentInformation
}
from
'
../aml/amlConfig
'
;
import
{
CommandChannel
,
RunnerConnection
}
from
"
../commandChannel
"
;
import
{
Channel
,
EnvironmentInformation
}
from
"
../environment
"
;
class
AMLRunnerConnection
extends
RunnerConnection
{
}
export
class
AMLCommandChannel
extends
CommandChannel
{
private
stopping
:
boolean
=
false
;
private
currentMessageIndex
:
number
=
-
1
;
private
sendQueues
:
[
EnvironmentInformation
,
string
][]
=
[];
private
readonly
NNI_METRICS_PATTERN
:
string
=
`NNISDK_MEb'(?<metrics>.*?)'`
;
public
constructor
(
commandEmitter
:
EventEmitter
)
{
super
(
commandEmitter
);
}
public
get
channelName
():
Channel
{
return
"
aml
"
;
}
public
async
config
(
_key
:
string
,
_value
:
any
):
Promise
<
void
>
{
// do nothing
}
public
async
start
():
Promise
<
void
>
{
// do nothing
}
public
async
stop
():
Promise
<
void
>
{
this
.
stopping
=
true
;
}
public
async
run
():
Promise
<
void
>
{
// start command loops
await
Promise
.
all
([
this
.
receiveLoop
(),
this
.
sendLoop
()
]);
}
protected
async
sendCommandInternal
(
environment
:
EnvironmentInformation
,
message
:
string
):
Promise
<
void
>
{
this
.
sendQueues
.
push
([
environment
,
message
]);
}
protected
createRunnerConnection
(
environment
:
EnvironmentInformation
):
RunnerConnection
{
return
new
AMLRunnerConnection
(
environment
);
}
private
async
sendLoop
():
Promise
<
void
>
{
const
intervalSeconds
=
0.5
;
while
(
!
this
.
stopping
)
{
const
start
=
new
Date
();
if
(
this
.
sendQueues
.
length
>
0
)
{
while
(
this
.
sendQueues
.
length
>
0
)
{
const
item
=
this
.
sendQueues
.
shift
();
if
(
item
===
undefined
)
{
break
;
}
const
environment
=
item
[
0
];
const
message
=
item
[
1
];
const
amlClient
=
(
environment
as
AMLEnvironmentInformation
).
amlClient
;
if
(
!
amlClient
)
{
throw
new
Error
(
'
aml client not initialized!
'
);
}
amlClient
.
sendCommand
(
message
);
}
}
const
end
=
new
Date
();
const
delayMs
=
intervalSeconds
*
1000
-
(
end
.
valueOf
()
-
start
.
valueOf
());
if
(
delayMs
>
0
)
{
await
delay
(
delayMs
);
}
}
}
private
async
receiveLoop
():
Promise
<
void
>
{
const
intervalSeconds
=
2
;
while
(
!
this
.
stopping
)
{
const
start
=
new
Date
();
const
runnerConnections
=
[...
this
.
runnerConnections
.
values
()]
as
AMLRunnerConnection
[];
for
(
const
runnerConnection
of
runnerConnections
)
{
// to loop all commands
const
amlClient
=
(
runnerConnection
.
environment
as
AMLEnvironmentInformation
).
amlClient
;
if
(
!
amlClient
)
{
throw
new
Error
(
'
AML client not initialized!
'
);
}
const
command
=
await
amlClient
.
receiveCommand
();
if
(
command
&&
Object
.
prototype
.
hasOwnProperty
.
call
(
command
,
"
trial_runner
"
))
{
const
messages
=
command
[
'
trial_runner
'
];
if
(
messages
)
{
if
(
messages
instanceof
Object
&&
this
.
currentMessageIndex
<
messages
.
length
-
1
)
{
for
(
let
index
=
this
.
currentMessageIndex
+
1
;
index
<
messages
.
length
;
index
++
)
{
this
.
handleCommand
(
runnerConnection
.
environment
,
messages
[
index
]);
}
this
.
currentMessageIndex
=
messages
.
length
-
1
;
}
else
if
(
this
.
currentMessageIndex
===
-
1
){
this
.
handleCommand
(
runnerConnection
.
environment
,
messages
);
this
.
currentMessageIndex
+=
1
;
}
}
}
}
const
end
=
new
Date
();
const
delayMs
=
intervalSeconds
*
1000
-
(
end
.
valueOf
()
-
start
.
valueOf
());
if
(
delayMs
>
0
)
{
await
delay
(
delayMs
);
}
}
}
}
src/nni_manager/training_service/reusable/channels/fileCommandChannel.ts
View file @
93f96d4f
...
...
@@ -6,7 +6,7 @@
import
*
as
component
from
"
../../../common/component
"
;
import
{
delay
}
from
"
../../../common/utils
"
;
import
{
CommandChannel
,
RunnerConnection
}
from
"
../commandChannel
"
;
import
{
EnvironmentInformation
,
Channel
}
from
"
../environment
"
;
import
{
Channel
,
EnvironmentInformation
}
from
"
../environment
"
;
import
{
StorageService
}
from
"
../storageService
"
;
class
FileHandler
{
...
...
@@ -38,15 +38,21 @@ export class FileCommandChannel extends CommandChannel {
}
public
async
start
():
Promise
<
void
>
{
// start command loops
this
.
receiveLoop
();
this
.
sendLoop
();
// do nothing
}
public
async
stop
():
Promise
<
void
>
{
this
.
stopping
=
true
;
}
public
async
run
():
Promise
<
void
>
{
// start command loops
await
Promise
.
all
([
this
.
receiveLoop
(),
this
.
sendLoop
()
]);
}
protected
async
sendCommandInternal
(
environment
:
EnvironmentInformation
,
message
:
string
):
Promise
<
void
>
{
this
.
sendQueues
.
push
([
environment
,
message
]);
}
...
...
src/nni_manager/training_service/reusable/channels/webCommandChannel.ts
View file @
93f96d4f
...
...
@@ -66,6 +66,10 @@ export class WebCommandChannel extends CommandChannel {
}
}
public
async
run
():
Promise
<
void
>
{
// do nothing
}
protected
async
sendCommandInternal
(
environment
:
EnvironmentInformation
,
message
:
string
):
Promise
<
void
>
{
if
(
this
.
webSocketServer
===
undefined
)
{
throw
new
Error
(
`WebCommandChannel: uninitialized!`
)
...
...
src/nni_manager/training_service/reusable/commandChannel.ts
View file @
93f96d4f
...
...
@@ -59,6 +59,9 @@ export abstract class CommandChannel {
public
abstract
start
():
Promise
<
void
>
;
public
abstract
stop
():
Promise
<
void
>
;
// Pull-based command channels need loop to check messages, the loop should be started with await here.
public
abstract
run
():
Promise
<
void
>
;
protected
abstract
sendCommandInternal
(
environment
:
EnvironmentInformation
,
message
:
string
):
Promise
<
void
>
;
protected
abstract
createRunnerConnection
(
environment
:
EnvironmentInformation
):
RunnerConnection
;
...
...
src/nni_manager/training_service/reusable/environment.ts
View file @
93f96d4f
...
...
@@ -14,7 +14,6 @@ import { CommandChannel } from "./commandChannel";
export
type
EnvironmentStatus
=
'
UNKNOWN
'
|
'
WAITING
'
|
'
RUNNING
'
|
'
SUCCEEDED
'
|
'
FAILED
'
|
'
USER_CANCELED
'
;
export
type
Channel
=
"
web
"
|
"
file
"
|
"
aml
"
|
"
ut
"
;
export
class
EnvironmentInformation
{
private
log
:
Logger
;
...
...
@@ -65,6 +64,7 @@ export class EnvironmentInformation {
}
}
}
export
abstract
class
EnvironmentService
{
public
abstract
get
hasStorageService
():
boolean
;
...
...
src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
0 → 100644
View file @
93f96d4f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
TrialConfigMetadataKey
}
from
'
../../common/trialConfigMetadataKey
'
;
import
{
AMLClusterConfig
,
AMLTrialConfig
}
from
'
../aml/amlConfig
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
AMLEnvironmentInformation
}
from
'
../aml/amlConfig
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
import
{
NNIManagerIpConfig
,
}
from
'
../../../common/trainingService
'
;
import
{
validateCodeDir
}
from
'
../../common/util
'
;
import
{
getExperimentRootDir
}
from
'
../../../common/utils
'
;
import
{
AMLCommandChannel
}
from
'
../channels/amlCommandChannel
'
;
import
{
CommandChannel
}
from
"
../commandChannel
"
;
import
{
EventEmitter
}
from
"
events
"
;
/**
* Collector PAI jobs info from PAI cluster, and update pai job status locally
*/
@
component
.
Singleton
export
class
AMLEnvironmentService
extends
EnvironmentService
{
private
readonly
log
:
Logger
=
getLogger
();
public
amlClusterConfig
:
AMLClusterConfig
|
undefined
;
public
amlTrialConfig
:
AMLTrialConfig
|
undefined
;
private
amlJobConfig
:
any
;
private
stopping
:
boolean
=
false
;
private
versionCheck
:
boolean
=
true
;
private
isMultiPhase
:
boolean
=
false
;
private
nniVersion
?:
string
;
private
experimentId
:
string
;
private
nniManagerIpConfig
?:
NNIManagerIpConfig
;
private
experimentRootDir
:
string
;
constructor
()
{
super
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentRootDir
=
getExperimentRootDir
();
}
public
get
hasStorageService
():
boolean
{
return
false
;
}
public
getCommandChannel
(
commandEmitter
:
EventEmitter
):
CommandChannel
{
return
new
AMLCommandChannel
(
commandEmitter
);
}
public
createEnviornmentInfomation
(
envId
:
string
,
envName
:
string
):
EnvironmentInformation
{
return
new
AMLEnvironmentInformation
(
envId
,
envName
);
}
public
async
config
(
key
:
string
,
value
:
string
):
Promise
<
void
>
{
switch
(
key
)
{
case
TrialConfigMetadataKey
.
AML_CLUSTER_CONFIG
:
this
.
amlClusterConfig
=
<
AMLClusterConfig
>
JSON
.
parse
(
value
);
break
;
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
if
(
this
.
amlClusterConfig
===
undefined
)
{
this
.
log
.
error
(
'
aml cluster config is not initialized
'
);
break
;
}
this
.
amlTrialConfig
=
<
AMLTrialConfig
>
JSON
.
parse
(
value
);
// Validate to make sure codeDir doesn't have too many files
await
validateCodeDir
(
this
.
amlTrialConfig
.
codeDir
);
break
;
}
default
:
this
.
log
.
debug
(
`AML not proccessed metadata key: '
${
key
}
', value: '
${
value
}
'`
);
}
}
public
async
refreshEnvironmentsStatus
(
environments
:
EnvironmentInformation
[]):
Promise
<
void
>
{
environments
.
forEach
(
async
(
environment
)
=>
{
const
amlClient
=
(
environment
as
AMLEnvironmentInformation
).
amlClient
;
if
(
!
amlClient
)
{
throw
new
Error
(
'
AML client not initialized!
'
);
}
const
status
=
await
amlClient
.
updateStatus
(
environment
.
status
);
switch
(
status
.
toUpperCase
())
{
case
'
WAITING
'
:
case
'
RUNNING
'
:
case
'
QUEUED
'
:
// RUNNING status is set by runner, and ignore waiting status
break
;
case
'
COMPLETED
'
:
case
'
SUCCEEDED
'
:
environment
.
setFinalStatus
(
'
SUCCEEDED
'
);
break
;
case
'
FAILED
'
:
environment
.
setFinalStatus
(
'
FAILED
'
);
break
;
case
'
STOPPED
'
:
case
'
STOPPING
'
:
environment
.
setFinalStatus
(
'
USER_CANCELED
'
);
break
;
default
:
environment
.
setFinalStatus
(
'
UNKNOWN
'
);
}
});
}
public
async
startEnvironment
(
environment
:
EnvironmentInformation
):
Promise
<
void
>
{
if
(
this
.
amlClusterConfig
===
undefined
)
{
throw
new
Error
(
'
AML Cluster config is not initialized
'
);
}
if
(
this
.
amlTrialConfig
===
undefined
)
{
throw
new
Error
(
'
AML trial config is not initialized
'
);
}
const
amlEnvironment
:
AMLEnvironmentInformation
=
environment
as
AMLEnvironmentInformation
;
const
environmentLocalTempFolder
=
path
.
join
(
this
.
experimentRootDir
,
this
.
experimentId
,
"
environment-temp
"
);
environment
.
command
=
`import os\nos.system('
${
amlEnvironment
.
command
}
')`
;
await
fs
.
promises
.
writeFile
(
path
.
join
(
environmentLocalTempFolder
,
'
nni_script.py
'
),
amlEnvironment
.
command
,{
encoding
:
'
utf8
'
});
const
amlClient
=
new
AMLClient
(
this
.
amlClusterConfig
.
subscriptionId
,
this
.
amlClusterConfig
.
resourceGroup
,
this
.
amlClusterConfig
.
workspaceName
,
this
.
experimentId
,
this
.
amlTrialConfig
.
computeTarget
,
this
.
amlTrialConfig
.
image
,
'
nni_script.py
'
,
environmentLocalTempFolder
);
amlEnvironment
.
id
=
await
amlClient
.
submit
();
amlEnvironment
.
trackingUrl
=
await
amlClient
.
getTrackingUrl
();
amlEnvironment
.
amlClient
=
amlClient
;
}
public
async
stopEnvironment
(
environment
:
EnvironmentInformation
):
Promise
<
void
>
{
const
amlEnvironment
:
AMLEnvironmentInformation
=
environment
as
AMLEnvironmentInformation
;
const
amlClient
=
amlEnvironment
.
amlClient
;
if
(
!
amlClient
)
{
throw
new
Error
(
'
AML client not initialized!
'
);
}
amlClient
.
stop
();
}
}
src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
View file @
93f96d4f
...
...
@@ -167,8 +167,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
}
// Step 1. Prepare PAI job configuration
environment
.
runnerWorkingFolder
=
`
${
this
.
paiTrialConfig
.
containerNFSMountPath
}
/
${
this
.
experimentId
}
/envs/
${
environment
.
id
}
`
;
environment
.
command
=
`cd
${
environment
.
runnerWorkingFolder
}
&&
${
environment
.
command
}
`
const
environmentRoot
=
`
${
this
.
paiTrialConfig
.
containerNFSMountPath
}
/
${
this
.
experimentId
}
`
;
environment
.
runnerWorkingFolder
=
`
${
environmentRoot
}
/envs/
${
environment
.
id
}
`
;
environment
.
command
=
`cd
${
environmentRoot
}
&&
${
environment
.
command
}
`
environment
.
trackingUrl
=
`
${
this
.
protocol
}
://
${
this
.
paiClusterConfig
.
host
}
/job-detail.html?username=
${
this
.
paiClusterConfig
.
userName
}
&jobName=
${
environment
.
jobId
}
`
// Step 2. Generate Job Configuration in yaml format
...
...
src/nni_manager/training_service/reusable/routerTrainingService.ts
View file @
93f96d4f
...
...
@@ -13,6 +13,7 @@ import { PAIClusterConfig } from '../pai/paiConfig';
import
{
PAIK8STrainingService
}
from
'
../pai/paiK8S/paiK8STrainingService
'
;
import
{
EnvironmentService
}
from
'
./environment
'
;
import
{
OpenPaiEnvironmentService
}
from
'
./environments/openPaiEnvironmentService
'
;
import
{
AMLEnvironmentService
}
from
'
./environments/amlEnvironmentService
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
import
{
TrialDispatcher
}
from
'
./trialDispatcher
'
;
...
...
@@ -120,6 +121,25 @@ class RouterTrainingService implements TrainingService {
}
await
this
.
internalTrainingService
.
setClusterMetadata
(
key
,
value
);
this
.
metaDataCache
.
clear
();
}
else
if
(
key
===
TrialConfigMetadataKey
.
AML_CLUSTER_CONFIG
)
{
this
.
internalTrainingService
=
component
.
get
(
TrialDispatcher
);
Container
.
bind
(
EnvironmentService
)
.
to
(
AMLEnvironmentService
)
.
scope
(
Scope
.
Singleton
);
for
(
const
[
key
,
value
]
of
this
.
metaDataCache
)
{
if
(
this
.
internalTrainingService
===
undefined
)
{
throw
new
Error
(
"
TrainingService is not assigned!
"
);
}
await
this
.
internalTrainingService
.
setClusterMetadata
(
key
,
value
);
}
if
(
this
.
internalTrainingService
===
undefined
)
{
throw
new
Error
(
"
TrainingService is not assigned!
"
);
}
await
this
.
internalTrainingService
.
setClusterMetadata
(
key
,
value
);
this
.
metaDataCache
.
clear
();
}
else
{
this
.
log
.
debug
(
`caching metadata key:{} value:{}, as training service is not determined.`
);
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment