Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
277e63f2
"docs/source/Tutorial/Tensorboard.rst" did not exist on "1418a366bc537ce9166b3e18cdbedd84ad406f8c"
Unverified
Commit
277e63f2
authored
May 27, 2021
by
liuzhe-lz
Committed by
GitHub
May 27, 2021
Browse files
Support 3rd-party training service (#3662)
parent
e349b440
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
277 additions
and
106 deletions
+277
-106
nni/experiment/config/util.py
nni/experiment/config/util.py
+14
-1
nni/runtime/default_config/training_services.json
nni/runtime/default_config/training_services.json
+1
-0
nni/runtime/platform/__init__.py
nni/runtime/platform/__init__.py
+1
-3
nni/tools/nnictl/nnictl.py
nni/tools/nnictl/nnictl.py
+18
-0
nni/tools/nnictl/ts_management.py
nni/tools/nnictl/ts_management.py
+77
-0
ts/nni_manager/common/log.ts
ts/nni_manager/common/log.ts
+8
-20
ts/nni_manager/common/nniConfig.ts
ts/nni_manager/common/nniConfig.ts
+37
-0
ts/nni_manager/common/pythonScript.ts
ts/nni_manager/common/pythonScript.ts
+33
-0
ts/nni_manager/common/utils.ts
ts/nni_manager/common/utils.ts
+7
-4
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+4
-4
ts/nni_manager/main.ts
ts/nni_manager/main.ts
+0
-5
ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
...ng_service/reusable/environments/amlEnvironmentService.ts
+3
-5
ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts
...ervice/reusable/environments/environmentServiceFactory.ts
+25
-14
ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts
..._service/reusable/environments/localEnvironmentService.ts
+4
-7
ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
...ervice/reusable/environments/openPaiEnvironmentService.ts
+2
-3
ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts
...service/reusable/environments/remoteEnvironmentService.ts
+7
-9
ts/nni_manager/training_service/reusable/routerTrainingService.ts
...anager/training_service/reusable/routerTrainingService.ts
+12
-9
ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts
...er/training_service/reusable/test/trialDispatcher.test.ts
+1
-1
ts/nni_manager/training_service/reusable/trialDispatcher.ts
ts/nni_manager/training_service/reusable/trialDispatcher.ts
+23
-21
No files found.
nni/experiment/config/util.py
View file @
277e63f2
...
@@ -5,11 +5,15 @@
...
@@ -5,11 +5,15 @@
Miscellaneous utility functions.
Miscellaneous utility functions.
"""
"""
import
importlib
import
json
import
math
import
math
import
os.path
import
os.path
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Union
,
List
from
typing
import
Any
,
Dict
,
Optional
,
Union
,
List
import
nni.runtime.config
PathLike
=
Union
[
Path
,
str
]
PathLike
=
Union
[
Path
,
str
]
def
case_insensitive
(
key_or_kwargs
:
Union
[
str
,
Dict
[
str
,
Any
]])
->
Union
[
str
,
Dict
[
str
,
Any
]]:
def
case_insensitive
(
key_or_kwargs
:
Union
[
str
,
Dict
[
str
,
Any
]])
->
Union
[
str
,
Dict
[
str
,
Any
]]:
...
@@ -34,6 +38,14 @@ def training_service_config_factory(
...
@@ -34,6 +38,14 @@ def training_service_config_factory(
config
:
Union
[
List
,
Dict
]
=
None
,
config
:
Union
[
List
,
Dict
]
=
None
,
base_path
:
Optional
[
Path
]
=
None
):
# -> TrainingServiceConfig
base_path
:
Optional
[
Path
]
=
None
):
# -> TrainingServiceConfig
from
.common
import
TrainingServiceConfig
from
.common
import
TrainingServiceConfig
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path
=
nni
.
runtime
.
config
.
get_config_file
(
'training_services.json'
)
custom_ts_config
=
json
.
load
(
custom_ts_config_path
.
open
())
for
custom_ts_pkg
in
custom_ts_config
.
keys
():
pkg
=
importlib
.
import_module
(
custom_ts_pkg
)
_config_class
=
pkg
.
nni_training_service_info
.
config_class
ts_configs
=
[]
ts_configs
=
[]
if
platform
is
not
None
:
if
platform
is
not
None
:
assert
config
is
None
assert
config
is
None
...
@@ -42,7 +54,8 @@ def training_service_config_factory(
...
@@ -42,7 +54,8 @@ def training_service_config_factory(
if
cls
.
platform
in
platforms
:
if
cls
.
platform
in
platforms
:
ts_configs
.
append
(
cls
())
ts_configs
.
append
(
cls
())
if
len
(
ts_configs
)
<
len
(
platforms
):
if
len
(
ts_configs
)
<
len
(
platforms
):
raise
RuntimeError
(
'There is unrecognized platform!'
)
bad
=
', '
.
join
(
set
(
platforms
)
-
set
(
ts_configs
))
raise
RuntimeError
(
f
'Bad training service platform:
{
bad
}
'
)
else
:
else
:
assert
config
is
not
None
assert
config
is
not
None
supported_platforms
=
{
cls
.
platform
:
cls
for
cls
in
TrainingServiceConfig
.
__subclasses__
()}
supported_platforms
=
{
cls
.
platform
:
cls
for
cls
in
TrainingServiceConfig
.
__subclasses__
()}
...
...
nni/runtime/default_config/training_services.json
0 → 100644
View file @
277e63f2
{}
nni/runtime/platform/__init__.py
View file @
277e63f2
...
@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None:
...
@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'dlts'
,
'aml'
,
'adl'
,
'hybrid'
):
from
.local
import
*
else
:
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
from
.local
import
*
nni/tools/nnictl/nnictl.py
View file @
277e63f2
...
@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
...
@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
save_experiment
,
load_experiment
save_experiment
,
load_experiment
from
.algo_management
import
algo_reg
,
algo_unreg
,
algo_show
,
algo_list
from
.algo_management
import
algo_reg
,
algo_unreg
,
algo_show
,
algo_list
from
.constants
import
DEFAULT_REST_PORT
from
.constants
import
DEFAULT_REST_PORT
from
.
import
ts_management
init
(
autoreset
=
True
)
init
(
autoreset
=
True
)
if
os
.
environ
.
get
(
'COVERAGE_PROCESS_START'
):
if
os
.
environ
.
get
(
'COVERAGE_PROCESS_START'
):
...
@@ -242,6 +244,22 @@ def parse_args():
...
@@ -242,6 +244,22 @@ def parse_args():
parser_algo_list
=
parser_algo_subparsers
.
add_parser
(
'list'
,
help
=
'list registered algorithms'
)
parser_algo_list
=
parser_algo_subparsers
.
add_parser
(
'list'
,
help
=
'list registered algorithms'
)
parser_algo_list
.
set_defaults
(
func
=
algo_list
)
parser_algo_list
.
set_defaults
(
func
=
algo_list
)
#parse trainingservice command
parser_ts
=
subparsers
.
add_parser
(
'trainingservice'
,
help
=
'control training service'
)
# add subparsers for parser_ts
parser_ts_subparsers
=
parser_ts
.
add_subparsers
()
parser_ts_reg
=
parser_ts_subparsers
.
add_parser
(
'register'
,
help
=
'register training service'
)
parser_ts_reg
.
add_argument
(
'--package'
,
dest
=
'package'
,
help
=
'package name'
,
required
=
True
)
parser_ts_reg
.
set_defaults
(
func
=
ts_management
.
register
)
parser_ts_unreg
=
parser_ts_subparsers
.
add_parser
(
'unregister'
,
help
=
'unregister training service'
)
parser_ts_unreg
.
add_argument
(
'--package'
,
dest
=
'package'
,
help
=
'package name'
,
required
=
True
)
parser_ts_unreg
.
set_defaults
(
func
=
ts_management
.
unregister
)
parser_ts_list
=
parser_ts_subparsers
.
add_parser
(
'list'
,
help
=
'list custom training services'
)
parser_ts_list
.
set_defaults
(
func
=
ts_management
.
list_services
)
# Show a message that the `nnictl package` command has been replaced by `nnictl algo`; to be removed in a future release.
# Show a message that the `nnictl package` command has been replaced by `nnictl algo`; to be removed in a future release.
def
show_messsage_for_nnictl_package
(
args
):
def
show_messsage_for_nnictl_package
(
args
):
print_error
(
'nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage'
)
print_error
(
'nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage'
)
...
...
nni/tools/nnictl/ts_management.py
0 → 100644
View file @
277e63f2
import
importlib
import
json
from
nni.runtime.config
import
get_config_file
from
.common_utils
import
print_error
,
print_green
_builtin_training_services
=
[
'local'
,
'remote'
,
'openpai'
,
'pai'
,
'aml'
,
'kubeflow'
,
'frameworkcontroller'
,
'adl'
,
]
def register(args):
    """Register a third-party training service package.

    ``args.package`` is the importable name of the package. The package is
    validated step by step; on the first failure an error is printed and
    nothing is written:

    1. it must not collide with a builtin training service name;
    2. it must be importable;
    3. it must expose a ``nni_training_service_info`` attribute;
    4. ``nni_training_service_info.config_class`` must be constructible
       without arguments;
    5. ``node_module_path`` and ``node_class_name`` must be JSON serializable.

    On success the entry is stored in ``training_services.json`` under the
    package name, replacing any previous entry for the same package.
    """
    if args.package in _builtin_training_services:
        print_error(f'{args.package} is a builtin training service')
        return

    try:
        module = importlib.import_module(args.package)
    except Exception:
        print_error(f'Cannot import package {args.package}')
        return

    try:
        info = module.nni_training_service_info
    except Exception:
        print_error(f'Cannot read nni_training_service_info from {args.package}')
        return

    try:
        # Only checks that the class can be instantiated; the instance is discarded.
        info.config_class()
    except Exception:
        print_error('Bad experiment config class')
        return

    try:
        # The keys are camelCase because the NNI manager (TypeScript) reads
        # `nodeModulePath` and `nodeClassName` from this file; snake_case keys
        # would be silently read back as `undefined` there.
        service_config = {
            'nodeModulePath': info.node_module_path,
            'nodeClassName': info.node_class_name,
        }
        # Dry-run serialization so a bad value is reported before touching the file.
        json.dumps(service_config)
    except Exception:
        print_error('Bad node_module_path or bad node_class_name')
        return

    config = _load()
    update = args.package in config
    config[args.package] = service_config
    _save(config)

    if update:
        print_green(f'Sucessfully updated {args.package}')
    else:
        print_green(f'Sucessfully registered {args.package}')
def unregister(args):
    """Remove the training service named ``args.package`` from the registry.

    Prints an error and leaves the registry file untouched when the package
    is not currently registered.
    """
    registry = _load()
    package = args.package

    if package in registry:
        registry.pop(package, None)
        _save(registry)
        print_green(f'Sucessfully unregistered {package}')
    else:
        print_error(f'{package} is not a registered training service')
def list_services(_):
    """Print the names of all registered custom training services, one per line.

    The single positional argument (the parsed CLI args) is ignored.
    """
    registry = _load()
    listing = '\n'.join(registry.keys())
    print(listing)
def _load():
    """Return the parsed content of the ``training_services.json`` config file."""
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with get_config_file('training_services.json').open() as config_file:
        return json.load(config_file)
def _save(config):
    """Write ``config`` to the ``training_services.json`` config file as indented JSON."""
    # Closing the handle explicitly guarantees the data is flushed to disk
    # before this function returns.
    with get_config_file('training_services.json').open('w') as config_file:
        json.dump(config, config_file, indent=4)
ts/nni_manager/common/log.ts
View file @
277e63f2
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
*
as
fs
from
'
fs
'
;
import
*
as
fs
from
'
fs
'
;
import
{
Writable
}
from
'
stream
'
;
import
{
Writable
}
from
'
stream
'
;
import
*
as
util
from
'
util
'
;
/* log level constants */
/* log level constants */
...
@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([
...
@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([
/* global_ states */
/* global_ states */
let
logFile
:
Writable
|
null
=
null
;
let
logLevel
:
number
=
0
;
let
logLevel
:
number
=
0
;
const
loggers
=
new
Map
<
string
,
Logger
>
();
const
loggers
=
new
Map
<
string
,
Logger
>
();
...
@@ -70,7 +70,8 @@ export class Logger {
...
@@ -70,7 +70,8 @@ export class Logger {
}
}
private
log
(
level
:
number
,
args
:
any
[]):
void
{
private
log
(
level
:
number
,
args
:
any
[]):
void
{
if
(
level
<
logLevel
||
logFile
===
null
)
{
const
logFile
:
Writable
|
undefined
=
(
global
as
any
).
logFile
;
if
(
level
<
logLevel
||
logFile
===
undefined
)
{
return
;
return
;
}
}
...
@@ -80,20 +81,7 @@ export class Logger {
...
@@ -80,20 +81,7 @@ export class Logger {
const
levelName
=
levelNames
.
has
(
level
)
?
levelNames
.
get
(
level
)
:
level
.
toString
();
const
levelName
=
levelNames
.
has
(
level
)
?
levelNames
.
get
(
level
)
:
level
.
toString
();
const
words
=
[];
const
message
=
args
.
map
(
arg
=>
(
typeof
arg
===
'
string
'
?
arg
:
util
.
inspect
(
arg
))).
join
(
'
'
);
for
(
const
arg
of
args
)
{
if
(
arg
===
undefined
)
{
words
.
push
(
'
undefined
'
);
}
else
if
(
arg
===
null
)
{
words
.
push
(
'
null
'
);
}
else
if
(
typeof
arg
===
'
object
'
)
{
const
json
=
JSON
.
stringify
(
arg
);
words
.
push
(
json
===
undefined
?
arg
:
json
);
}
else
{
words
.
push
(
arg
);
}
}
const
message
=
words
.
join
(
'
'
);
const
record
=
`[
${
time
}
]
${
levelName
}
(
${
this
.
name
}
)
${
message
}
\n`
;
const
record
=
`[
${
time
}
]
${
levelName
}
(
${
this
.
name
}
)
${
message
}
\n`
;
logFile
.
write
(
record
);
logFile
.
write
(
record
);
...
@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void {
...
@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void {
}
}
export
function
startLogging
(
logPath
:
string
):
void
{
export
function
startLogging
(
logPath
:
string
):
void
{
logFile
=
fs
.
createWriteStream
(
logPath
,
{
(
global
as
any
).
logFile
=
fs
.
createWriteStream
(
logPath
,
{
flags
:
'
a+
'
,
flags
:
'
a+
'
,
encoding
:
'
utf8
'
,
encoding
:
'
utf8
'
,
autoClose
:
true
autoClose
:
true
...
@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void {
...
@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void {
}
}
export
function
stopLogging
():
void
{
export
function
stopLogging
():
void
{
if
(
logFile
!==
null
)
{
if
(
(
global
as
any
).
logFile
!==
undefined
)
{
logFile
.
end
();
(
global
as
any
).
logFile
.
end
();
logFile
=
null
;
(
global
as
any
).
logFile
=
undefined
;
}
}
}
}
ts/nni_manager/common/nniConfig.ts
0 → 100644
View file @
277e63f2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
{
promisify
}
from
'
util
'
;
import
{
runPythonScript
}
from
'
./pythonScript
'
;
/**
 * One entry of `training_services.json`, as returned by
 * `getCustomEnvironmentServiceConfig()`.
 */
export interface CustomEnvironmentServiceConfig {
    // Key under which the entry was looked up in the config file.
    name: string;
    // Value of the entry's `nodeModulePath` field.
    nodeModulePath: string;
    // Value of the entry's `nodeClassName` field.
    nodeClassName: string;
}
const
readFile
=
promisify
(
fs
.
readFile
);
/**
 * Read a file from NNI's config directory and return its content as a string.
 *
 * The directory is obtained by running a short Python snippet that prints
 * `nni.runtime.config.get_config_directory()`.
 */
async function readConfigFile(fileName: string): Promise<string> {
    const locateDirScript = 'import nni.runtime.config ; print(nni.runtime.config.get_config_directory())';
    const scriptOutput = await runPythonScript(locateDirScript);
    // The printed directory ends with a newline; strip it before joining.
    const filePath = path.join(scriptOutput.trim(), fileName);
    const content = await readFile(filePath);
    return content.toString();
}
/**
 * Look up `name` in `training_services.json`.
 *
 * Resolves to `null` when the file has no entry for `name`; otherwise resolves
 * to the entry's `nodeModulePath` and `nodeClassName` together with `name`.
 */
export async function getCustomEnvironmentServiceConfig(name: string): Promise<CustomEnvironmentServiceConfig | null> {
    const registry = JSON.parse(await readConfigFile('training_services.json'));
    const entry = registry[name];
    if (entry === undefined) {
        return null;
    }
    return {
        name: name,
        nodeModulePath: entry.nodeModulePath as string,
        nodeClassName: entry.nodeClassName as string,
    };
}
ts/nni_manager/common/pythonScript.ts
0 → 100644
View file @
277e63f2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
{
spawn
}
from
'
child_process
'
;
import
{
Logger
,
getLogger
}
from
'
./log
'
;
const
python
=
process
.
platform
===
'
win32
'
?
'
python.exe
'
:
'
python3
'
;
/**
 * Run `script` with `python -c` and resolve to everything it wrote to stdout.
 *
 * If the script wrote anything to stderr, the script and its stderr output are
 * logged as warnings (using `logger`, or the default logger when omitted);
 * the promise still resolves normally in that case.
 *
 * Rejects if the interpreter process cannot be spawned.
 */
export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
    const proc = spawn(python, ['-c', script]);

    // Collect output as it arrives. Calling `read()` once after the process
    // ends returns `null` when nothing is buffered (e.g. empty stderr), and
    // an unread pipe can also stall a script that prints a lot.
    const stdoutChunks: Buffer[] = [];
    const stderrChunks: Buffer[] = [];
    proc.stdout.on('data', (chunk: Buffer) => { stdoutChunks.push(chunk); });
    proc.stderr.on('data', (chunk: Buffer) => { stderrChunks.push(chunk); });

    await new Promise<void>((resolve, reject) => {
        proc.on('error', (err: Error) => { reject(err); });
        // 'close' (unlike 'exit') is emitted only after the stdio streams
        // have been closed, so all output has been delivered by then.
        proc.on('close', () => { resolve(); });
    });

    const stdout = Buffer.concat(stdoutChunks).toString();
    const stderr = Buffer.concat(stderrChunks).toString();

    if (stderr) {
        if (logger === undefined) {
            logger = getLogger();
        }
        logger.warning('python script has stderr.');
        logger.warning('script:', script);
        logger.warning('stderr:', stderr);
    }

    return stdout;
}
ts/nni_manager/common/utils.ts
View file @
277e63f2
...
@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager';
...
@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager';
import
{
HyperParameters
,
TrainingService
,
TrialJobStatus
}
from
'
./trainingService
'
;
import
{
HyperParameters
,
TrainingService
,
TrialJobStatus
}
from
'
./trainingService
'
;
function
getExperimentRootDir
():
string
{
function
getExperimentRootDir
():
string
{
return
getExperimentStartupInfo
()
return
getExperimentStartupInfo
().
getLogDir
();
.
getLogDir
();
}
}
function
getLogDir
():
string
{
function
getLogDir
():
string
{
...
@@ -34,8 +33,7 @@ function getLogDir(): string {
...
@@ -34,8 +33,7 @@ function getLogDir(): string {
}
}
function
getLogLevel
():
string
{
function
getLogLevel
():
string
{
return
getExperimentStartupInfo
()
return
getExperimentStartupInfo
().
getLogLevel
();
.
getLogLevel
();
}
}
function
getDefaultDatabaseDir
():
string
{
function
getDefaultDatabaseDir
():
string
{
...
@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu
...
@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu
}
}
}
}
/**
 * Load a Node.js module given a file-system path to it.
 *
 * The module's parent directory is prepended to this module's search paths
 * and the last path component is then `require`d by name.
 */
export function importModule(modulePath: string): any {
    const searchDir = path.dirname(modulePath);
    const moduleName = path.basename(modulePath);
    // Prepending gives this directory priority over the existing search paths.
    module.paths.unshift(searchDir);
    return require(moduleName);
}
export
{
export
{
countFilesRecursively
,
validateFileNameRecursively
,
generateParamFileName
,
getMsgDispatcherCommand
,
getCheckpointDir
,
getExperimentsInfoPath
,
countFilesRecursively
,
validateFileNameRecursively
,
generateParamFileName
,
getMsgDispatcherCommand
,
getCheckpointDir
,
getExperimentsInfoPath
,
getLogDir
,
getExperimentRootDir
,
getJobCancelStatus
,
getDefaultDatabaseDir
,
getIPV4Address
,
unixPathJoin
,
withLockSync
,
getFreePort
,
isPortOpen
,
getLogDir
,
getExperimentRootDir
,
getJobCancelStatus
,
getDefaultDatabaseDir
,
getIPV4Address
,
unixPathJoin
,
withLockSync
,
getFreePort
,
isPortOpen
,
...
...
ts/nni_manager/core/nnimanager.ts
View file @
277e63f2
...
@@ -445,10 +445,7 @@ class NNIManager implements Manager {
...
@@ -445,10 +445,7 @@ class NNIManager implements Manager {
throw
new
Error
(
'
Cannot detect training service platform
'
);
throw
new
Error
(
'
Cannot detect training service platform
'
);
}
}
if
([
'
remote
'
,
'
pai
'
,
'
aml
'
,
'
hybrid
'
].
includes
(
platform
))
{
if
(
platform
===
'
local
'
)
{
const
module_
=
await
import
(
'
../training_service/reusable/routerTrainingService
'
);
return
new
module_
.
RouterTrainingService
(
config
);
}
else
if
(
platform
===
'
local
'
)
{
const
module_
=
await
import
(
'
../training_service/local/localTrainingService
'
);
const
module_
=
await
import
(
'
../training_service/local/localTrainingService
'
);
return
new
module_
.
LocalTrainingService
(
config
);
return
new
module_
.
LocalTrainingService
(
config
);
}
else
if
(
platform
===
'
kubeflow
'
)
{
}
else
if
(
platform
===
'
kubeflow
'
)
{
...
@@ -460,6 +457,9 @@ class NNIManager implements Manager {
...
@@ -460,6 +457,9 @@ class NNIManager implements Manager {
}
else
if
(
platform
===
'
adl
'
)
{
}
else
if
(
platform
===
'
adl
'
)
{
const
module_
=
await
import
(
'
../training_service/kubernetes/adl/adlTrainingService
'
);
const
module_
=
await
import
(
'
../training_service/kubernetes/adl/adlTrainingService
'
);
return
new
module_
.
AdlTrainingService
();
return
new
module_
.
AdlTrainingService
();
}
else
{
const
module_
=
await
import
(
'
../training_service/reusable/routerTrainingService
'
);
return
await
module_
.
RouterTrainingService
.
construct
(
config
);
}
}
throw
new
Error
(
`Unsupported training service platform "
${
platform
}
"`
);
throw
new
Error
(
`Unsupported training service platform "
${
platform
}
"`
);
...
...
ts/nni_manager/main.ts
View file @
277e63f2
...
@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
...
@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const
port
:
number
=
parseInt
(
strPort
,
10
);
const
port
:
number
=
parseInt
(
strPort
,
10
);
const
mode
:
string
=
parseArg
([
'
--mode
'
,
'
-m
'
]);
const
mode
:
string
=
parseArg
([
'
--mode
'
,
'
-m
'
]);
if
(
!
[
'
local
'
,
'
remote
'
,
'
pai
'
,
'
kubeflow
'
,
'
frameworkcontroller
'
,
'
dlts
'
,
'
aml
'
,
'
adl
'
,
'
hybrid
'
].
includes
(
mode
))
{
console
.
log
(
`FATAL: unknown mode:
${
mode
}
`
);
usage
();
process
.
exit
(
1
);
}
const
startMode
:
string
=
parseArg
([
'
--start_mode
'
,
'
-s
'
]);
const
startMode
:
string
=
parseArg
([
'
--start_mode
'
,
'
-s
'
]);
if
(
!
[
ExperimentStartUpMode
.
NEW
,
ExperimentStartUpMode
.
RESUME
].
includes
(
startMode
))
{
if
(
!
[
ExperimentStartUpMode
.
NEW
,
ExperimentStartUpMode
.
RESUME
].
includes
(
startMode
))
{
...
...
ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
View file @
277e63f2
...
@@ -6,9 +6,7 @@
...
@@ -6,9 +6,7 @@
import
*
as
fs
from
'
fs
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getExperimentRootDir
}
from
'
../../../common/utils
'
;
import
{
ExperimentConfig
,
AmlConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
AmlConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
validateCodeDir
}
from
'
../../common/util
'
;
import
{
validateCodeDir
}
from
'
../../common/util
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
...
@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService {
...
@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private
experimentRootDir
:
string
;
private
experimentRootDir
:
string
;
private
config
:
FlattenAmlConfig
;
private
config
:
FlattenAmlConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
experimentRootDir
=
getE
xperimentRootDir
()
;
this
.
experimentRootDir
=
e
xperimentRootDir
;
this
.
config
=
flattenConfig
(
config
,
'
aml
'
);
this
.
config
=
flattenConfig
(
config
,
'
aml
'
);
validateCodeDir
(
this
.
config
.
trialCodeDirectory
);
validateCodeDir
(
this
.
config
.
trialCodeDirectory
);
}
}
...
...
ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts
View file @
277e63f2
...
@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService';
...
@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import
{
RemoteEnvironmentService
}
from
'
./remoteEnvironmentService
'
;
import
{
RemoteEnvironmentService
}
from
'
./remoteEnvironmentService
'
;
import
{
EnvironmentService
}
from
'
../environment
'
;
import
{
EnvironmentService
}
from
'
../environment
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getCustomEnvironmentServiceConfig
}
from
'
../../../common/nniConfig
'
;
import
{
getExperimentRootDir
,
importModule
}
from
'
../../../common/utils
'
;
export
class
EnvironmentServiceFactory
{
public
static
createEnvironmentService
(
name
:
string
,
config
:
ExperimentConfig
):
EnvironmentService
{
export
async
function
createEnvironmentService
(
name
:
string
,
config
:
ExperimentConfig
):
Promise
<
EnvironmentService
>
{
switch
(
name
)
{
const
expId
=
getExperimentId
();
case
'
local
'
:
const
rootDir
=
getExperimentRootDir
();
return
new
LocalEnvironmentService
(
config
);
case
'
remote
'
:
switch
(
name
)
{
return
new
RemoteEnvironmentService
(
config
);
case
'
local
'
:
case
'
aml
'
:
return
new
LocalEnvironmentService
(
rootDir
,
expId
,
config
);
return
new
AMLEnvironmentService
(
config
);
case
'
remote
'
:
case
'
openpai
'
:
return
new
RemoteEnvironmentService
(
rootDir
,
expId
,
config
);
return
new
OpenPaiEnvironmentService
(
config
);
case
'
aml
'
:
default
:
return
new
AMLEnvironmentService
(
rootDir
,
expId
,
config
);
throw
new
Error
(
`
${
name
}
not supported!`
);
case
'
openpai
'
:
}
return
new
OpenPaiEnvironmentService
(
rootDir
,
expId
,
config
);
}
const
esConfig
=
await
getCustomEnvironmentServiceConfig
(
name
);
if
(
esConfig
===
null
)
{
throw
new
Error
(
`
${
name
}
is not a supported training service!`
);
}
}
const
esModule
=
importModule
(
esConfig
.
nodeModulePath
);
const
esClass
=
esModule
[
esConfig
.
nodeClassName
]
as
any
;
return
new
esClass
(
rootDir
,
expId
,
config
);
}
}
ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts
View file @
277e63f2
...
@@ -7,11 +7,10 @@ import * as fs from 'fs';
...
@@ -7,11 +7,10 @@ import * as fs from 'fs';
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
*
as
tkill
from
'
tree-kill
'
;
import
*
as
tkill
from
'
tree-kill
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
getExperimentRootDir
,
isAlive
,
getNewLine
}
from
'
../../../common/utils
'
;
import
{
isAlive
,
getNewLine
}
from
'
../../../common/utils
'
;
import
{
execMkdir
,
runScript
,
getScriptName
,
execCopydir
}
from
'
../../common/util
'
;
import
{
execMkdir
,
runScript
,
getScriptName
,
execCopydir
}
from
'
../../common/util
'
;
import
{
SharedStorageService
}
from
'
../sharedStorage
'
import
{
SharedStorageService
}
from
'
../sharedStorage
'
...
@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService {
...
@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private
experimentRootDir
:
string
;
private
experimentRootDir
:
string
;
private
experimentId
:
string
;
private
experimentId
:
string
;
constructor
(
_config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
_config
:
ExperimentConfig
)
{
super
();
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
experimentRootDir
=
getE
xperimentRootDir
()
;
this
.
experimentRootDir
=
e
xperimentRootDir
;
}
}
public
get
environmentMaintenceLoopInterval
():
number
{
public
get
environmentMaintenceLoopInterval
():
number
{
...
@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService {
...
@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService {
const
sharedStorageService
=
component
.
get
<
SharedStorageService
>
(
SharedStorageService
);
const
sharedStorageService
=
component
.
get
<
SharedStorageService
>
(
SharedStorageService
);
if
(
environment
.
useSharedStorage
&&
sharedStorageService
.
canLocalMounted
)
{
if
(
environment
.
useSharedStorage
&&
sharedStorageService
.
canLocalMounted
)
{
this
.
experimentRootDir
=
sharedStorageService
.
localWorkingRoot
;
this
.
experimentRootDir
=
sharedStorageService
.
localWorkingRoot
;
}
else
{
this
.
experimentRootDir
=
getExperimentRootDir
();
}
}
const
localEnvCodeFolder
:
string
=
path
.
join
(
this
.
experimentRootDir
,
"
envs
"
);
const
localEnvCodeFolder
:
string
=
path
.
join
(
this
.
experimentRootDir
,
"
envs
"
);
if
(
environment
.
useSharedStorage
&&
!
sharedStorageService
.
canLocalMounted
)
{
if
(
environment
.
useSharedStorage
&&
!
sharedStorageService
.
canLocalMounted
)
{
...
...
ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
View file @
277e63f2
...
@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml';
...
@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml';
import
*
as
request
from
'
request
'
;
import
*
as
request
from
'
request
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
ExperimentConfig
,
OpenpaiConfig
,
flattenConfig
,
toMegaBytes
}
from
'
../../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
OpenpaiConfig
,
flattenConfig
,
toMegaBytes
}
from
'
../../../common/experimentConfig
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
PAIClusterConfig
}
from
'
../../pai/paiConfig
'
;
import
{
PAIClusterConfig
}
from
'
../../pai/paiConfig
'
;
...
@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
...
@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private
experimentId
:
string
;
private
experimentId
:
string
;
private
config
:
FlattenOpenpaiConfig
;
private
config
:
FlattenOpenpaiConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
_experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
config
=
flattenConfig
(
config
,
'
openpai
'
);
this
.
config
=
flattenConfig
(
config
,
'
openpai
'
);
this
.
paiToken
=
this
.
config
.
token
;
this
.
paiToken
=
this
.
config
.
token
;
this
.
protocol
=
this
.
config
.
host
.
toLowerCase
().
startsWith
(
'
https://
'
)
?
'
https
'
:
'
http
'
;
this
.
protocol
=
this
.
config
.
host
.
toLowerCase
().
startsWith
(
'
https://
'
)
?
'
https
'
:
'
http
'
;
...
...
ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts
View file @
277e63f2
...
@@ -6,10 +6,9 @@
...
@@ -6,10 +6,9 @@
import
*
as
fs
from
'
fs
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
getExperimentRootDir
,
getLogLevel
}
from
'
../../../common/utils
'
;
import
{
getLogLevel
}
from
'
../../../common/utils
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
RemoteMachineConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
RemoteMachineConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
execMkdir
}
from
'
../../common/util
'
;
import
{
execMkdir
}
from
'
../../common/util
'
;
import
{
ExecutorManager
}
from
'
../../remote_machine/remoteMachineData
'
;
import
{
ExecutorManager
}
from
'
../../remote_machine/remoteMachineData
'
;
...
@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService {
...
@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService {
private
experimentId
:
string
;
private
experimentId
:
string
;
private
config
:
FlattenRemoteConfig
;
private
config
:
FlattenRemoteConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
environmentExecutorManagerMap
=
new
Map
<
string
,
ExecutorManager
>
();
this
.
environmentExecutorManagerMap
=
new
Map
<
string
,
ExecutorManager
>
();
this
.
machineExecutorManagerMap
=
new
Map
<
RemoteMachineConfig
,
ExecutorManager
>
();
this
.
machineExecutorManagerMap
=
new
Map
<
RemoteMachineConfig
,
ExecutorManager
>
();
this
.
remoteMachineMetaOccupiedMap
=
new
Map
<
RemoteMachineConfig
,
boolean
>
();
this
.
remoteMachineMetaOccupiedMap
=
new
Map
<
RemoteMachineConfig
,
boolean
>
();
this
.
experimentRootDir
=
getExperimentRootDir
();
this
.
experimentRootDir
=
experimentRootDir
;
this
.
experimentId
=
getExperimentId
();
this
.
log
=
getLogger
();
this
.
log
=
getLogger
();
this
.
config
=
flattenConfig
(
config
,
'
remote
'
);
this
.
config
=
flattenConfig
(
config
,
'
remote
'
);
...
@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
...
@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
// Create root working directory after executor is ready
// Create root working directory after executor is ready
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni-experiments
'
);
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni-experiments
'
);
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getE
xperimentId
()
));
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
this
.
e
xperimentId
));
// the directory to store temp scripts in remote machine
// the directory to store temp scripts in remote machine
const
remoteGpuScriptCollectorDir
:
string
=
executor
.
getRemoteScriptsPath
(
getE
xperimentId
()
);
const
remoteGpuScriptCollectorDir
:
string
=
executor
.
getRemoteScriptsPath
(
this
.
e
xperimentId
);
// clean up previous result.
// clean up previous result.
await
executor
.
createFolder
(
remoteGpuScriptCollectorDir
,
true
);
await
executor
.
createFolder
(
remoteGpuScriptCollectorDir
,
true
);
...
@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
...
@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw
new
Error
(
`Mount shared storage on remote machine failed.\n ERROR:
${
result
.
stderr
}
`
);
throw
new
Error
(
`Mount shared storage on remote machine failed.\n ERROR:
${
result
.
stderr
}
`
);
}
}
}
else
{
}
else
{
this
.
remoteExperimentRootDir
=
executor
.
getRemoteExperimentRootDir
(
getE
xperimentId
()
);
this
.
remoteExperimentRootDir
=
executor
.
getRemoteExperimentRootDir
(
this
.
e
xperimentId
);
}
}
environment
.
command
=
await
this
.
getScript
(
environment
);
environment
.
command
=
await
this
.
getScript
(
environment
);
...
...
ts/nni_manager/training_service/reusable/routerTrainingService.ts
View file @
277e63f2
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
'
use strict
'
;
'
use strict
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
OpenpaiConfig
}
from
'
../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
OpenpaiConfig
}
from
'
../../common/experimentConfig
'
;
...
@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher';
...
@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher';
* It's a intermedia implementation to support reusable training service.
* It's a intermedia implementation to support reusable training service.
* The final goal is to support reusable training job in higher level than training service.
* The final goal is to support reusable training job in higher level than training service.
*/
*/
@
component
.
Singleton
class
RouterTrainingService
implements
TrainingService
{
class
RouterTrainingService
implements
TrainingService
{
pr
otected
readonly
log
:
Logger
;
pr
ivate
log
!
:
Logger
;
private
internalTrainingService
:
TrainingService
;
private
internalTrainingService
!
:
TrainingService
;
constructor
(
config
:
ExperimentConfig
)
{
public
static
async
construct
(
config
:
ExperimentConfig
):
Promise
<
RouterTrainingService
>
{
this
.
log
=
getLogger
();
const
instance
=
new
RouterTrainingService
();
instance
.
log
=
getLogger
(
'
RouterTrainingService
'
);
const
platform
=
Array
.
isArray
(
config
.
trainingService
)
?
'
hybrid
'
:
config
.
trainingService
.
platform
;
const
platform
=
Array
.
isArray
(
config
.
trainingService
)
?
'
hybrid
'
:
config
.
trainingService
.
platform
;
if
(
platform
===
'
remote
'
&&
!
(
<
RemoteConfig
>
config
.
trainingService
).
reuseMode
)
{
if
(
platform
===
'
remote
'
&&
!
(
<
RemoteConfig
>
config
.
trainingService
).
reuseMode
)
{
this
.
internalTrainingService
=
new
RemoteMachineTrainingService
(
config
);
instance
.
internalTrainingService
=
new
RemoteMachineTrainingService
(
config
);
}
else
if
(
platform
===
'
openpai
'
&&
!
(
<
OpenpaiConfig
>
config
.
trainingService
).
reuseMode
)
{
}
else
if
(
platform
===
'
openpai
'
&&
!
(
<
OpenpaiConfig
>
config
.
trainingService
).
reuseMode
)
{
this
.
internalTrainingService
=
new
PAITrainingService
(
config
);
instance
.
internalTrainingService
=
new
PAITrainingService
(
config
);
}
else
{
}
else
{
this
.
internalTrainingService
=
new
TrialDispatcher
(
config
);
instance
.
internalTrainingService
=
await
TrialDispatcher
.
construct
(
config
);
}
}
return
instance
;
}
}
// eslint-disable-next-line @typescript-eslint/no-empty-function
private
constructor
()
{
}
public
async
listTrialJobs
():
Promise
<
TrialJobDetail
[]
>
{
public
async
listTrialJobs
():
Promise
<
TrialJobDetail
[]
>
{
if
(
this
.
internalTrainingService
===
undefined
)
{
if
(
this
.
internalTrainingService
===
undefined
)
{
throw
new
Error
(
"
TrainingService is not assigned!
"
);
throw
new
Error
(
"
TrainingService is not assigned!
"
);
...
...
ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts
View file @
277e63f2
...
@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => {
...
@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => {
});
});
beforeEach
(
async
()
=>
{
beforeEach
(
async
()
=>
{
trialDispatcher
=
new
TrialDispatcher
(
config
);
trialDispatcher
=
await
TrialDispatcher
.
construct
(
config
);
// set ut environment
// set ut environment
let
environmentServiceList
:
EnvironmentService
[]
=
[];
let
environmentServiceList
:
EnvironmentService
[]
=
[];
...
...
ts/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
277e63f2
...
@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig';
...
@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig';
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInformation
,
RunnerSettings
,
TrialGpuSummary
}
from
'
./environment
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInformation
,
RunnerSettings
,
TrialGpuSummary
}
from
'
./environment
'
;
import
{
EnvironmentService
Factory
}
from
'
./environments/environmentServiceFactory
'
;
import
{
create
EnvironmentService
}
from
'
./environments/environmentServiceFactory
'
;
import
{
GpuScheduler
}
from
'
./gpuScheduler
'
;
import
{
GpuScheduler
}
from
'
./gpuScheduler
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
...
@@ -39,20 +39,20 @@ import { TrialDetail } from './trial';
...
@@ -39,20 +39,20 @@ import { TrialDetail } from './trial';
**/
**/
@
component
.
Singleton
@
component
.
Singleton
class
TrialDispatcher
implements
TrainingService
{
class
TrialDispatcher
implements
TrainingService
{
private
readonly
log
:
Logger
;
private
log
:
Logger
;
private
readonly
isDeveloping
:
boolean
=
false
;
private
isDeveloping
:
boolean
=
false
;
private
stopping
:
boolean
=
false
;
private
stopping
:
boolean
=
false
;
private
readonly
metricsEmitter
:
EventEmitter
;
private
metricsEmitter
:
EventEmitter
;
private
readonly
experimentId
:
string
;
private
experimentId
:
string
;
private
readonly
experimentRootDir
:
string
;
private
experimentRootDir
:
string
;
private
enableVersionCheck
:
boolean
=
true
;
private
enableVersionCheck
:
boolean
=
true
;
private
trialConfig
:
TrialConfig
|
undefined
;
private
trialConfig
:
TrialConfig
|
undefined
;
private
readonly
trials
:
Map
<
string
,
TrialDetail
>
;
private
trials
:
Map
<
string
,
TrialDetail
>
;
private
readonly
environments
:
Map
<
string
,
EnvironmentInformation
>
;
private
environments
:
Map
<
string
,
EnvironmentInformation
>
;
// make public for ut
// make public for ut
public
environmentServiceList
:
EnvironmentService
[]
=
[];
public
environmentServiceList
:
EnvironmentService
[]
=
[];
public
commandChannelSet
:
Set
<
CommandChannel
>
;
public
commandChannelSet
:
Set
<
CommandChannel
>
;
...
@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService {
...
@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService {
private
config
:
ExperimentConfig
;
private
config
:
ExperimentConfig
;
constructor
(
config
:
ExperimentConfig
)
{
public
static
async
construct
(
config
:
ExperimentConfig
):
Promise
<
TrialDispatcher
>
{
this
.
log
=
getLogger
();
const
instance
=
new
TrialDispatcher
(
config
);
await
instance
.
asyncConstructor
(
config
);
return
instance
;
}
private
constructor
(
config
:
ExperimentConfig
)
{
this
.
log
=
getLogger
(
'
TrialDispatcher
'
);
this
.
trials
=
new
Map
<
string
,
TrialDetail
>
();
this
.
trials
=
new
Map
<
string
,
TrialDetail
>
();
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
metricsEmitter
=
new
EventEmitter
();
this
.
metricsEmitter
=
new
EventEmitter
();
...
@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService {
...
@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService {
if
(
this
.
enableGpuScheduler
)
{
if
(
this
.
enableGpuScheduler
)
{
this
.
log
.
info
(
`TrialDispatcher: GPU scheduler is enabled.`
)
this
.
log
.
info
(
`TrialDispatcher: GPU scheduler is enabled.`
)
}
}
}
validateCodeDir
(
config
.
trialCodeDirectory
);
private
async
asyncConstructor
(
config
:
ExperimentConfig
):
Promise
<
void
>
{
await
validateCodeDir
(
config
.
trialCodeDirectory
);
if
(
Array
.
isArray
(
config
.
trainingService
))
{
const
serviceConfigs
=
Array
.
isArray
(
config
.
trainingService
)
?
config
.
trainingService
:
[
config
.
trainingService
];
config
.
trainingService
.
forEach
(
trainingService
=>
{
const
servicePromises
=
serviceConfigs
.
map
(
serviceConfig
=>
createEnvironmentService
(
serviceConfig
.
platform
,
config
));
const
env
=
EnvironmentServiceFactory
.
createEnvironmentService
(
trainingService
.
platform
,
config
);
this
.
environmentServiceList
=
await
Promise
.
all
(
servicePromises
);
this
.
environmentServiceList
.
push
(
env
);
});
}
else
{
const
env
=
EnvironmentServiceFactory
.
createEnvironmentService
(
config
.
trainingService
.
platform
,
config
);
this
.
environmentServiceList
.
push
(
env
);
}
this
.
environmentMaintenceLoopInterval
=
Math
.
max
(
this
.
environmentMaintenceLoopInterval
=
Math
.
max
(
...
this
.
environmentServiceList
.
map
((
env
)
=>
env
.
environmentMaintenceLoopInterval
)
...
this
.
environmentServiceList
.
map
((
env
)
=>
env
.
environmentMaintenceLoopInterval
)
...
@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService {
}
}
if
(
this
.
config
.
sharedStorage
!==
undefined
)
{
if
(
this
.
config
.
sharedStorage
!==
undefined
)
{
this
.
initializeSharedStorage
(
this
.
config
.
sharedStorage
);
await
this
.
initializeSharedStorage
(
this
.
config
.
sharedStorage
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment