Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
277e63f2
"testing/vscode:/vscode.git/clone" did not exist on "73a6cb8bfd1f6dfc6197b7ad9253719dd720d681"
Unverified
Commit
277e63f2
authored
May 27, 2021
by
liuzhe-lz
Committed by
GitHub
May 27, 2021
Browse files
Support 3rd-party training service (#3662)
parent
e349b440
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
277 additions
and
106 deletions
+277
-106
nni/experiment/config/util.py
nni/experiment/config/util.py
+14
-1
nni/runtime/default_config/training_services.json
nni/runtime/default_config/training_services.json
+1
-0
nni/runtime/platform/__init__.py
nni/runtime/platform/__init__.py
+1
-3
nni/tools/nnictl/nnictl.py
nni/tools/nnictl/nnictl.py
+18
-0
nni/tools/nnictl/ts_management.py
nni/tools/nnictl/ts_management.py
+77
-0
ts/nni_manager/common/log.ts
ts/nni_manager/common/log.ts
+8
-20
ts/nni_manager/common/nniConfig.ts
ts/nni_manager/common/nniConfig.ts
+37
-0
ts/nni_manager/common/pythonScript.ts
ts/nni_manager/common/pythonScript.ts
+33
-0
ts/nni_manager/common/utils.ts
ts/nni_manager/common/utils.ts
+7
-4
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+4
-4
ts/nni_manager/main.ts
ts/nni_manager/main.ts
+0
-5
ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
...ng_service/reusable/environments/amlEnvironmentService.ts
+3
-5
ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts
...ervice/reusable/environments/environmentServiceFactory.ts
+25
-14
ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts
..._service/reusable/environments/localEnvironmentService.ts
+4
-7
ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
...ervice/reusable/environments/openPaiEnvironmentService.ts
+2
-3
ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts
...service/reusable/environments/remoteEnvironmentService.ts
+7
-9
ts/nni_manager/training_service/reusable/routerTrainingService.ts
...anager/training_service/reusable/routerTrainingService.ts
+12
-9
ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts
...er/training_service/reusable/test/trialDispatcher.test.ts
+1
-1
ts/nni_manager/training_service/reusable/trialDispatcher.ts
ts/nni_manager/training_service/reusable/trialDispatcher.ts
+23
-21
No files found.
nni/experiment/config/util.py
View file @
277e63f2
...
...
@@ -5,11 +5,15 @@
Miscellaneous utility functions.
"""
import
importlib
import
json
import
math
import
os.path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Union
,
List
import
nni.runtime.config
PathLike
=
Union
[
Path
,
str
]
def
case_insensitive
(
key_or_kwargs
:
Union
[
str
,
Dict
[
str
,
Any
]])
->
Union
[
str
,
Dict
[
str
,
Any
]]:
...
...
@@ -34,6 +38,14 @@ def training_service_config_factory(
config
:
Union
[
List
,
Dict
]
=
None
,
base_path
:
Optional
[
Path
]
=
None
):
# -> TrainingServiceConfig
from
.common
import
TrainingServiceConfig
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path
=
nni
.
runtime
.
config
.
get_config_file
(
'training_services.json'
)
custom_ts_config
=
json
.
load
(
custom_ts_config_path
.
open
())
for
custom_ts_pkg
in
custom_ts_config
.
keys
():
pkg
=
importlib
.
import_module
(
custom_ts_pkg
)
_config_class
=
pkg
.
nni_training_service_info
.
config_class
ts_configs
=
[]
if
platform
is
not
None
:
assert
config
is
None
...
...
@@ -42,7 +54,8 @@ def training_service_config_factory(
if
cls
.
platform
in
platforms
:
ts_configs
.
append
(
cls
())
if
len
(
ts_configs
)
<
len
(
platforms
):
raise
RuntimeError
(
'There is unrecognized platform!'
)
bad
=
', '
.
join
(
set
(
platforms
)
-
set
(
ts_configs
))
raise
RuntimeError
(
f
'Bad training service platform:
{
bad
}
'
)
else
:
assert
config
is
not
None
supported_platforms
=
{
cls
.
platform
:
cls
for
cls
in
TrainingServiceConfig
.
__subclasses__
()}
...
...
nni/runtime/default_config/training_services.json
0 → 100644
View file @
277e63f2
{}
nni/runtime/platform/__init__.py
View file @
277e63f2
...
...
@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'dlts'
,
'aml'
,
'adl'
,
'hybrid'
):
from
.local
import
*
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
from
.local
import
*
nni/tools/nnictl/nnictl.py
View file @
277e63f2
...
...
@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
save_experiment
,
load_experiment
from
.algo_management
import
algo_reg
,
algo_unreg
,
algo_show
,
algo_list
from
.constants
import
DEFAULT_REST_PORT
from
.
import
ts_management
init
(
autoreset
=
True
)
if
os
.
environ
.
get
(
'COVERAGE_PROCESS_START'
):
...
...
@@ -242,6 +244,22 @@ def parse_args():
parser_algo_list
=
parser_algo_subparsers
.
add_parser
(
'list'
,
help
=
'list registered algorithms'
)
parser_algo_list
.
set_defaults
(
func
=
algo_list
)
#parse trainingservice command
parser_ts
=
subparsers
.
add_parser
(
'trainingservice'
,
help
=
'control training service'
)
# add subparsers for parser_ts
parser_ts_subparsers
=
parser_ts
.
add_subparsers
()
parser_ts_reg
=
parser_ts_subparsers
.
add_parser
(
'register'
,
help
=
'register training service'
)
parser_ts_reg
.
add_argument
(
'--package'
,
dest
=
'package'
,
help
=
'package name'
,
required
=
True
)
parser_ts_reg
.
set_defaults
(
func
=
ts_management
.
register
)
parser_ts_unreg
=
parser_ts_subparsers
.
add_parser
(
'unregister'
,
help
=
'unregister training service'
)
parser_ts_unreg
.
add_argument
(
'--package'
,
dest
=
'package'
,
help
=
'package name'
,
required
=
True
)
parser_ts_unreg
.
set_defaults
(
func
=
ts_management
.
unregister
)
parser_ts_list
=
parser_ts_subparsers
.
add_parser
(
'list'
,
help
=
'list custom training services'
)
parser_ts_list
.
set_defaults
(
func
=
ts_management
.
list_services
)
# Show a message that the `nnictl package` command has been replaced by `nnictl algo`; to be removed in a future release.
def
show_messsage_for_nnictl_package
(
args
):
print_error
(
'nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage'
)
...
...
nni/tools/nnictl/ts_management.py
0 → 100644
View file @
277e63f2
import
importlib
import
json
from
nni.runtime.config
import
get_config_file
from
.common_utils
import
print_error
,
print_green
# Platform names that `register` refuses to accept as a custom package name,
# reporting them as builtin training services instead.
_builtin_training_services = [
    'local',
    'remote',
    'openpai',
    'pai',
    'aml',
    'kubeflow',
    'frameworkcontroller',
    'adl',
]
def register(args):
    """Register a third-party training service package.

    The package named by ``args.package`` is imported and validated step by
    step; the first failing step prints an error and aborts without touching
    the config file:

    1. the name must not collide with a builtin training service;
    2. the package must be importable;
    3. it must expose ``nni_training_service_info``;
    4. ``info.config_class()`` must be constructible without arguments;
    5. ``info.node_module_path`` / ``info.node_class_name`` must be readable
       and JSON-serializable.

    On success the entry is written to ``training_services.json`` (replacing
    an existing entry of the same name) and a confirmation is printed.

    Parameters
    ----------
    args
        Parsed command-line arguments; only ``args.package`` is read.
    """
    if args.package in _builtin_training_services:
        print_error(f'{args.package} is a builtin training service')
        return

    try:
        module = importlib.import_module(args.package)
    except Exception:
        print_error(f'Cannot import package {args.package}')
        return

    try:
        info = module.nni_training_service_info
    except Exception:
        print_error(f'Cannot read nni_training_service_info from {args.package}')
        return

    try:
        # Instantiated only to verify the class can be built; the result is discarded.
        info.config_class()
    except Exception:
        print_error('Bad experiment config class')
        return

    try:
        # Keys are camelCase because the NNI manager (Node.js side) looks up
        # `nodeModulePath` and `nodeClassName` when it reads this file.
        service_config = {
            'nodeModulePath': info.node_module_path,
            'nodeClassName': info.node_class_name,
        }
        # Dry-run serialization so a bad value is reported here rather than
        # leaving a half-written config file behind.
        json.dumps(service_config)
    except Exception:
        print_error('Bad node_module_path or bad node_class_name')
        return

    config = _load()
    update = args.package in config

    config[args.package] = service_config
    _save(config)

    if update:
        print_green(f'Sucessfully updated {args.package}')
    else:
        print_green(f'Sucessfully registered {args.package}')
def unregister(args):
    """Remove a previously registered custom training service.

    Reads ``args.package``; if no entry of that name exists in
    ``training_services.json`` an error is printed and the file is left
    untouched, otherwise the entry is dropped, the file is rewritten and a
    confirmation is printed.
    """
    registered = _load()
    package = args.package

    if package in registered:
        registered.pop(package, None)
        _save(registered)
        print_green(f'Sucessfully unregistered {package}')
    else:
        print_error(f'{package} is not a registered training service')
def list_services(_):
    """Print the name of every registered custom training service, one per line.

    The single positional argument (the parsed command-line namespace) is ignored.
    """
    names = _load().keys()
    print('\n'.join(names))
def _load():
    """Return the parsed content of ``training_services.json``.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes instead of being left to the garbage collector.
    """
    # NOTE(review): assumes get_config_file() returns a path-like object whose
    # .open() yields a regular file object (context manager) -- confirm.
    with get_config_file('training_services.json').open() as config_file:
        return json.load(config_file)
def _save(config):
    """Write ``config`` to ``training_services.json`` as indented JSON.

    The file is opened inside a ``with`` block so the data is flushed and the
    handle closed before this function returns; callers that print a success
    message afterwards can rely on the file being complete on disk.
    """
    # NOTE(review): assumes get_config_file() returns a path-like object whose
    # .open('w') yields a regular file object (context manager) -- confirm.
    with get_config_file('training_services.json').open('w') as config_file:
        json.dump(config, config_file, indent=4)
ts/nni_manager/common/log.ts
View file @
277e63f2
...
...
@@ -5,6 +5,7 @@
import
*
as
fs
from
'
fs
'
;
import
{
Writable
}
from
'
stream
'
;
import
*
as
util
from
'
util
'
;
/* log level constants */
...
...
@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([
/* global_ states */
let
logFile
:
Writable
|
null
=
null
;
let
logLevel
:
number
=
0
;
const
loggers
=
new
Map
<
string
,
Logger
>
();
...
...
@@ -70,7 +70,8 @@ export class Logger {
}
private
log
(
level
:
number
,
args
:
any
[]):
void
{
if
(
level
<
logLevel
||
logFile
===
null
)
{
const
logFile
:
Writable
|
undefined
=
(
global
as
any
).
logFile
;
if
(
level
<
logLevel
||
logFile
===
undefined
)
{
return
;
}
...
...
@@ -80,20 +81,7 @@ export class Logger {
const
levelName
=
levelNames
.
has
(
level
)
?
levelNames
.
get
(
level
)
:
level
.
toString
();
const
words
=
[];
for
(
const
arg
of
args
)
{
if
(
arg
===
undefined
)
{
words
.
push
(
'
undefined
'
);
}
else
if
(
arg
===
null
)
{
words
.
push
(
'
null
'
);
}
else
if
(
typeof
arg
===
'
object
'
)
{
const
json
=
JSON
.
stringify
(
arg
);
words
.
push
(
json
===
undefined
?
arg
:
json
);
}
else
{
words
.
push
(
arg
);
}
}
const
message
=
words
.
join
(
'
'
);
const
message
=
args
.
map
(
arg
=>
(
typeof
arg
===
'
string
'
?
arg
:
util
.
inspect
(
arg
))).
join
(
'
'
);
const
record
=
`[
${
time
}
]
${
levelName
}
(
${
this
.
name
}
)
${
message
}
\n`
;
logFile
.
write
(
record
);
...
...
@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void {
}
export
function
startLogging
(
logPath
:
string
):
void
{
logFile
=
fs
.
createWriteStream
(
logPath
,
{
(
global
as
any
).
logFile
=
fs
.
createWriteStream
(
logPath
,
{
flags
:
'
a+
'
,
encoding
:
'
utf8
'
,
autoClose
:
true
...
...
@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void {
}
export
function
stopLogging
():
void
{
if
(
logFile
!==
null
)
{
logFile
.
end
();
logFile
=
null
;
if
(
(
global
as
any
).
logFile
!==
undefined
)
{
(
global
as
any
).
logFile
.
end
();
(
global
as
any
).
logFile
=
undefined
;
}
}
ts/nni_manager/common/nniConfig.ts
0 → 100644
View file @
277e63f2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
{
promisify
}
from
'
util
'
;
import
{
runPythonScript
}
from
'
./pythonScript
'
;
// Shape of one custom environment service entry as returned by
// getCustomEnvironmentServiceConfig().
export interface CustomEnvironmentServiceConfig {
    // Key under which the entry was found in training_services.json.
    name: string;
    // Value of the entry's `nodeModulePath` field.
    nodeModulePath: string;
    // Value of the entry's `nodeClassName` field.
    nodeClassName: string;
}
const readFile = promisify(fs.readFile);

/**
 * Read a file from NNI's config directory and return its content as a string.
 *
 * The directory is not known to the Node.js side, so it is obtained by running
 * a one-line Python script that prints `nni.runtime.config.get_config_directory()`.
 *
 * @param fileName  Name of the file, relative to the config directory.
 * @returns The file content converted with `toString()`.
 */
async function readConfigFile(fileName: string): Promise<string> {
    const script = 'import nni.runtime.config ; print(nni.runtime.config.get_config_directory())';
    const scriptOutput = await runPythonScript(script);
    // The script's output ends with a newline from print(); trim it off.
    const configPath = path.join(scriptOutput.trim(), fileName);
    const content = await readFile(configPath);
    return content.toString();
}
/**
 * Look up a custom environment service in `training_services.json`.
 *
 * @param name  Training service name, used as the key into the config object.
 * @returns The entry's module path and class name, or `null` when the file
 *          has no entry for `name`.
 */
export async function getCustomEnvironmentServiceConfig(name: string): Promise<CustomEnvironmentServiceConfig | null> {
    const configJson = await readConfigFile('training_services.json');
    const config = JSON.parse(configJson);
    const entry = config[name];
    if (entry === undefined) {
        return null;
    }
    // Entries may spell their fields in camelCase or in snake_case; accept
    // both so that neither spelling resolves to `undefined`.
    const nodeModulePath = entry.nodeModulePath !== undefined ? entry.nodeModulePath : entry.node_module_path;
    const nodeClassName = entry.nodeClassName !== undefined ? entry.nodeClassName : entry.node_class_name;
    return {
        name,
        nodeModulePath: nodeModulePath as string,
        nodeClassName: nodeClassName as string,
    };
}
ts/nni_manager/common/pythonScript.ts
0 → 100644
View file @
277e63f2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
{
spawn
}
from
'
child_process
'
;
import
{
Logger
,
getLogger
}
from
'
./log
'
;
// Executable name passed to spawn(): 'python.exe' when process.platform is
// 'win32', 'python3' on every other platform.
const python = process.platform === 'win32' ? 'python.exe' : 'python3';
/**
 * Run a Python snippet with `python -c` and return everything it wrote to stdout.
 *
 * Output is accumulated from the streams' 'data' events and the function waits
 * for the child's 'close' event, so it works when a stream produced no data
 * (where a single `read()` would return `null`) and does not miss output that
 * is still in flight when the process exits.
 *
 * @param script  Python source code passed to the interpreter's `-c` option.
 * @param logger  Logger used to report stderr output; when omitted, one is
 *                obtained from `getLogger()` only if there is something to report.
 * @returns The child's stdout as a string (empty string if it printed nothing).
 * @throws Rejects with the spawn error if the interpreter cannot be started.
 */
export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
    const proc = spawn(python, [ '-c', script ]);

    const stdoutChunks: Buffer[] = [];
    const stderrChunks: Buffer[] = [];
    proc.stdout.on('data', (chunk: Buffer) => { stdoutChunks.push(chunk); });
    proc.stderr.on('data', (chunk: Buffer) => { stderrChunks.push(chunk); });

    await new Promise<void>((resolve, reject) => {
        proc.on('error', (err: Error) => { reject(err); });
        // 'close' fires after the process has ended and its stdio streams are closed.
        proc.on('close', () => { resolve(); });
    });

    const stdout = Buffer.concat(stdoutChunks).toString();
    const stderr = Buffer.concat(stderrChunks).toString();

    if (stderr) {
        if (logger === undefined) {
            logger = getLogger();
        }
        logger.warning('python script has stderr.');
        logger.warning('script:', script);
        logger.warning('stderr:', stderr);
    }

    return stdout;
}
ts/nni_manager/common/utils.ts
View file @
277e63f2
...
...
@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager';
import
{
HyperParameters
,
TrainingService
,
TrialJobStatus
}
from
'
./trainingService
'
;
function
getExperimentRootDir
():
string
{
return
getExperimentStartupInfo
()
.
getLogDir
();
return
getExperimentStartupInfo
().
getLogDir
();
}
function
getLogDir
():
string
{
...
...
@@ -34,8 +33,7 @@ function getLogDir(): string {
}
function
getLogLevel
():
string
{
return
getExperimentStartupInfo
()
.
getLogLevel
();
return
getExperimentStartupInfo
().
getLogLevel
();
}
function
getDefaultDatabaseDir
():
string
{
...
...
@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu
}
}
/**
 * Load a module given its path.
 *
 * The path's directory part is prepended to this module's `module.paths`
 * search list and the last path component is then passed to `require()`.
 *
 * @param modulePath  Path of the module to load.
 * @returns Whatever `require()` returns for the module.
 */
export function importModule(modulePath: string): any {
    const directory = path.dirname(modulePath);
    const moduleName = path.basename(modulePath);
    module.paths.unshift(directory);
    return require(moduleName);
}
export
{
countFilesRecursively
,
validateFileNameRecursively
,
generateParamFileName
,
getMsgDispatcherCommand
,
getCheckpointDir
,
getExperimentsInfoPath
,
getLogDir
,
getExperimentRootDir
,
getJobCancelStatus
,
getDefaultDatabaseDir
,
getIPV4Address
,
unixPathJoin
,
withLockSync
,
getFreePort
,
isPortOpen
,
...
...
ts/nni_manager/core/nnimanager.ts
View file @
277e63f2
...
...
@@ -445,10 +445,7 @@ class NNIManager implements Manager {
throw
new
Error
(
'
Cannot detect training service platform
'
);
}
if
([
'
remote
'
,
'
pai
'
,
'
aml
'
,
'
hybrid
'
].
includes
(
platform
))
{
const
module_
=
await
import
(
'
../training_service/reusable/routerTrainingService
'
);
return
new
module_
.
RouterTrainingService
(
config
);
}
else
if
(
platform
===
'
local
'
)
{
if
(
platform
===
'
local
'
)
{
const
module_
=
await
import
(
'
../training_service/local/localTrainingService
'
);
return
new
module_
.
LocalTrainingService
(
config
);
}
else
if
(
platform
===
'
kubeflow
'
)
{
...
...
@@ -460,6 +457,9 @@ class NNIManager implements Manager {
}
else
if
(
platform
===
'
adl
'
)
{
const
module_
=
await
import
(
'
../training_service/kubernetes/adl/adlTrainingService
'
);
return
new
module_
.
AdlTrainingService
();
}
else
{
const
module_
=
await
import
(
'
../training_service/reusable/routerTrainingService
'
);
return
await
module_
.
RouterTrainingService
.
construct
(
config
);
}
throw
new
Error
(
`Unsupported training service platform "
${
platform
}
"`
);
...
...
ts/nni_manager/main.ts
View file @
277e63f2
...
...
@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const
port
:
number
=
parseInt
(
strPort
,
10
);
const
mode
:
string
=
parseArg
([
'
--mode
'
,
'
-m
'
]);
if
(
!
[
'
local
'
,
'
remote
'
,
'
pai
'
,
'
kubeflow
'
,
'
frameworkcontroller
'
,
'
dlts
'
,
'
aml
'
,
'
adl
'
,
'
hybrid
'
].
includes
(
mode
))
{
console
.
log
(
`FATAL: unknown mode:
${
mode
}
`
);
usage
();
process
.
exit
(
1
);
}
const
startMode
:
string
=
parseArg
([
'
--start_mode
'
,
'
-s
'
]);
if
(
!
[
ExperimentStartUpMode
.
NEW
,
ExperimentStartUpMode
.
RESUME
].
includes
(
startMode
))
{
...
...
ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts
View file @
277e63f2
...
...
@@ -6,9 +6,7 @@
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
getExperimentRootDir
}
from
'
../../../common/utils
'
;
import
{
ExperimentConfig
,
AmlConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
validateCodeDir
}
from
'
../../common/util
'
;
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
...
...
@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private
experimentRootDir
:
string
;
private
config
:
FlattenAmlConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentRootDir
=
getE
xperimentRootDir
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
experimentRootDir
=
e
xperimentRootDir
;
this
.
config
=
flattenConfig
(
config
,
'
aml
'
);
validateCodeDir
(
this
.
config
.
trialCodeDirectory
);
}
...
...
ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts
View file @
277e63f2
...
...
@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import
{
RemoteEnvironmentService
}
from
'
./remoteEnvironmentService
'
;
import
{
EnvironmentService
}
from
'
../environment
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getCustomEnvironmentServiceConfig
}
from
'
../../../common/nniConfig
'
;
import
{
getExperimentRootDir
,
importModule
}
from
'
../../../common/utils
'
;
export
class
EnvironmentServiceFactory
{
public
static
createEnvironmentService
(
name
:
string
,
config
:
ExperimentConfig
):
EnvironmentService
{
switch
(
name
)
{
case
'
local
'
:
return
new
LocalEnvironmentService
(
config
);
case
'
remote
'
:
return
new
RemoteEnvironmentService
(
config
);
case
'
aml
'
:
return
new
AMLEnvironmentService
(
config
);
case
'
openpai
'
:
return
new
OpenPaiEnvironmentService
(
config
);
default
:
throw
new
Error
(
`
${
name
}
not supported!`
);
}
/**
 * Create the environment service for the training service platform `name`.
 *
 * The four known platforms ('local', 'remote', 'aml', 'openpai') map directly
 * to their environment service classes. Any other name is looked up with
 * getCustomEnvironmentServiceConfig(); when found, the configured module is
 * loaded and the configured class is instantiated with the same three
 * constructor arguments.
 *
 * @param name    Training service platform name.
 * @param config  Experiment config, passed through to the service constructor.
 * @throws Error when `name` is neither a known platform nor a configured custom one.
 */
export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
    const experimentId = getExperimentId();
    const experimentRootDir = getExperimentRootDir();

    if (name === 'local') {
        return new LocalEnvironmentService(experimentRootDir, experimentId, config);
    }
    if (name === 'remote') {
        return new RemoteEnvironmentService(experimentRootDir, experimentId, config);
    }
    if (name === 'aml') {
        return new AMLEnvironmentService(experimentRootDir, experimentId, config);
    }
    if (name === 'openpai') {
        return new OpenPaiEnvironmentService(experimentRootDir, experimentId, config);
    }

    const customConfig = await getCustomEnvironmentServiceConfig(name);
    if (customConfig === null) {
        throw new Error(`${name} is not a supported training service!`);
    }

    const customModule = importModule(customConfig.nodeModulePath);
    const ServiceClass = customModule[customConfig.nodeClassName] as any;
    return new ServiceClass(experimentRootDir, experimentId, config);
}
ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts
View file @
277e63f2
...
...
@@ -7,11 +7,10 @@ import * as fs from 'fs';
import
*
as
path
from
'
path
'
;
import
*
as
tkill
from
'
tree-kill
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
ExperimentConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
getExperimentRootDir
,
isAlive
,
getNewLine
}
from
'
../../../common/utils
'
;
import
{
isAlive
,
getNewLine
}
from
'
../../../common/utils
'
;
import
{
execMkdir
,
runScript
,
getScriptName
,
execCopydir
}
from
'
../../common/util
'
;
import
{
SharedStorageService
}
from
'
../sharedStorage
'
...
...
@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private
experimentRootDir
:
string
;
private
experimentId
:
string
;
constructor
(
_config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
_config
:
ExperimentConfig
)
{
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentRootDir
=
getE
xperimentRootDir
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
experimentRootDir
=
e
xperimentRootDir
;
}
public
get
environmentMaintenceLoopInterval
():
number
{
...
...
@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService {
const
sharedStorageService
=
component
.
get
<
SharedStorageService
>
(
SharedStorageService
);
if
(
environment
.
useSharedStorage
&&
sharedStorageService
.
canLocalMounted
)
{
this
.
experimentRootDir
=
sharedStorageService
.
localWorkingRoot
;
}
else
{
this
.
experimentRootDir
=
getExperimentRootDir
();
}
const
localEnvCodeFolder
:
string
=
path
.
join
(
this
.
experimentRootDir
,
"
envs
"
);
if
(
environment
.
useSharedStorage
&&
!
sharedStorageService
.
canLocalMounted
)
{
...
...
ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts
View file @
277e63f2
...
...
@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml';
import
*
as
request
from
'
request
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
ExperimentConfig
,
OpenpaiConfig
,
flattenConfig
,
toMegaBytes
}
from
'
../../../common/experimentConfig
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
PAIClusterConfig
}
from
'
../../pai/paiConfig
'
;
...
...
@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private
experimentId
:
string
;
private
config
:
FlattenOpenpaiConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
_experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
config
=
flattenConfig
(
config
,
'
openpai
'
);
this
.
paiToken
=
this
.
config
.
token
;
this
.
protocol
=
this
.
config
.
host
.
toLowerCase
().
startsWith
(
'
https://
'
)
?
'
https
'
:
'
http
'
;
...
...
ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts
View file @
277e63f2
...
...
@@ -6,10 +6,9 @@
import
*
as
fs
from
'
fs
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../../../common/component
'
;
import
{
getExperimentId
}
from
'
../../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../../common/log
'
;
import
{
EnvironmentInformation
,
EnvironmentService
}
from
'
../environment
'
;
import
{
getExperimentRootDir
,
getLogLevel
}
from
'
../../../common/utils
'
;
import
{
getLogLevel
}
from
'
../../../common/utils
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
RemoteMachineConfig
,
flattenConfig
}
from
'
../../../common/experimentConfig
'
;
import
{
execMkdir
}
from
'
../../common/util
'
;
import
{
ExecutorManager
}
from
'
../../remote_machine/remoteMachineData
'
;
...
...
@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService {
private
experimentId
:
string
;
private
config
:
FlattenRemoteConfig
;
constructor
(
config
:
ExperimentConfig
)
{
constructor
(
experimentRootDir
:
string
,
experimentId
:
string
,
config
:
ExperimentConfig
)
{
super
();
this
.
experimentId
=
getE
xperimentId
()
;
this
.
experimentId
=
e
xperimentId
;
this
.
environmentExecutorManagerMap
=
new
Map
<
string
,
ExecutorManager
>
();
this
.
machineExecutorManagerMap
=
new
Map
<
RemoteMachineConfig
,
ExecutorManager
>
();
this
.
remoteMachineMetaOccupiedMap
=
new
Map
<
RemoteMachineConfig
,
boolean
>
();
this
.
experimentRootDir
=
getExperimentRootDir
();
this
.
experimentId
=
getExperimentId
();
this
.
experimentRootDir
=
experimentRootDir
;
this
.
log
=
getLogger
();
this
.
config
=
flattenConfig
(
config
,
'
remote
'
);
...
...
@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
// Create root working directory after executor is ready
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni-experiments
'
);
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getE
xperimentId
()
));
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
this
.
e
xperimentId
));
// the directory to store temp scripts in remote machine
const
remoteGpuScriptCollectorDir
:
string
=
executor
.
getRemoteScriptsPath
(
getE
xperimentId
()
);
const
remoteGpuScriptCollectorDir
:
string
=
executor
.
getRemoteScriptsPath
(
this
.
e
xperimentId
);
// clean up previous result.
await
executor
.
createFolder
(
remoteGpuScriptCollectorDir
,
true
);
...
...
@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw
new
Error
(
`Mount shared storage on remote machine failed.\n ERROR:
${
result
.
stderr
}
`
);
}
}
else
{
this
.
remoteExperimentRootDir
=
executor
.
getRemoteExperimentRootDir
(
getE
xperimentId
()
);
this
.
remoteExperimentRootDir
=
executor
.
getRemoteExperimentRootDir
(
this
.
e
xperimentId
);
}
environment
.
command
=
await
this
.
getScript
(
environment
);
...
...
ts/nni_manager/training_service/reusable/routerTrainingService.ts
View file @
277e63f2
...
...
@@ -3,7 +3,6 @@
'
use strict
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
OpenpaiConfig
}
from
'
../../common/experimentConfig
'
;
...
...
@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher';
* It's an intermediate implementation to support reusable training services.
* The final goal is to support reusable training jobs at a higher level than the training service.
*/
@
component
.
Singleton
class
RouterTrainingService
implements
TrainingService
{
pr
otected
readonly
log
:
Logger
;
private
internalTrainingService
:
TrainingService
;
pr
ivate
log
!
:
Logger
;
private
internalTrainingService
!
:
TrainingService
;
constructor
(
config
:
ExperimentConfig
)
{
this
.
log
=
getLogger
();
public
static
async
construct
(
config
:
ExperimentConfig
):
Promise
<
RouterTrainingService
>
{
const
instance
=
new
RouterTrainingService
();
instance
.
log
=
getLogger
(
'
RouterTrainingService
'
);
const
platform
=
Array
.
isArray
(
config
.
trainingService
)
?
'
hybrid
'
:
config
.
trainingService
.
platform
;
if
(
platform
===
'
remote
'
&&
!
(
<
RemoteConfig
>
config
.
trainingService
).
reuseMode
)
{
this
.
internalTrainingService
=
new
RemoteMachineTrainingService
(
config
);
instance
.
internalTrainingService
=
new
RemoteMachineTrainingService
(
config
);
}
else
if
(
platform
===
'
openpai
'
&&
!
(
<
OpenpaiConfig
>
config
.
trainingService
).
reuseMode
)
{
this
.
internalTrainingService
=
new
PAITrainingService
(
config
);
instance
.
internalTrainingService
=
new
PAITrainingService
(
config
);
}
else
{
this
.
internalTrainingService
=
new
TrialDispatcher
(
config
);
instance
.
internalTrainingService
=
await
TrialDispatcher
.
construct
(
config
);
}
return
instance
;
}
// eslint-disable-next-line @typescript-eslint/no-empty-function
private
constructor
()
{
}
public
async
listTrialJobs
():
Promise
<
TrialJobDetail
[]
>
{
if
(
this
.
internalTrainingService
===
undefined
)
{
throw
new
Error
(
"
TrainingService is not assigned!
"
);
...
...
ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts
View file @
277e63f2
...
...
@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => {
});
beforeEach
(
async
()
=>
{
trialDispatcher
=
new
TrialDispatcher
(
config
);
trialDispatcher
=
await
TrialDispatcher
.
construct
(
config
);
// set ut environment
let
environmentServiceList
:
EnvironmentService
[]
=
[];
...
...
ts/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
277e63f2
...
...
@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig';
import
{
validateCodeDir
}
from
'
../common/util
'
;
import
{
Command
,
CommandChannel
}
from
'
./commandChannel
'
;
import
{
EnvironmentInformation
,
EnvironmentService
,
NodeInformation
,
RunnerSettings
,
TrialGpuSummary
}
from
'
./environment
'
;
import
{
EnvironmentService
Factory
}
from
'
./environments/environmentServiceFactory
'
;
import
{
create
EnvironmentService
}
from
'
./environments/environmentServiceFactory
'
;
import
{
GpuScheduler
}
from
'
./gpuScheduler
'
;
import
{
MountedStorageService
}
from
'
./storages/mountedStorageService
'
;
import
{
StorageService
}
from
'
./storageService
'
;
...
...
@@ -39,20 +39,20 @@ import { TrialDetail } from './trial';
**/
@
component
.
Singleton
class
TrialDispatcher
implements
TrainingService
{
private
readonly
log
:
Logger
;
private
readonly
isDeveloping
:
boolean
=
false
;
private
log
:
Logger
;
private
isDeveloping
:
boolean
=
false
;
private
stopping
:
boolean
=
false
;
private
readonly
metricsEmitter
:
EventEmitter
;
private
readonly
experimentId
:
string
;
private
readonly
experimentRootDir
:
string
;
private
metricsEmitter
:
EventEmitter
;
private
experimentId
:
string
;
private
experimentRootDir
:
string
;
private
enableVersionCheck
:
boolean
=
true
;
private
trialConfig
:
TrialConfig
|
undefined
;
private
readonly
trials
:
Map
<
string
,
TrialDetail
>
;
private
readonly
environments
:
Map
<
string
,
EnvironmentInformation
>
;
private
trials
:
Map
<
string
,
TrialDetail
>
;
private
environments
:
Map
<
string
,
EnvironmentInformation
>
;
// make public for ut
public
environmentServiceList
:
EnvironmentService
[]
=
[];
public
commandChannelSet
:
Set
<
CommandChannel
>
;
...
...
@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService {
private
config
:
ExperimentConfig
;
constructor
(
config
:
ExperimentConfig
)
{
this
.
log
=
getLogger
();
public
static
async
construct
(
config
:
ExperimentConfig
):
Promise
<
TrialDispatcher
>
{
const
instance
=
new
TrialDispatcher
(
config
);
await
instance
.
asyncConstructor
(
config
);
return
instance
;
}
private
constructor
(
config
:
ExperimentConfig
)
{
this
.
log
=
getLogger
(
'
TrialDispatcher
'
);
this
.
trials
=
new
Map
<
string
,
TrialDetail
>
();
this
.
environments
=
new
Map
<
string
,
EnvironmentInformation
>
();
this
.
metricsEmitter
=
new
EventEmitter
();
...
...
@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService {
if
(
this
.
enableGpuScheduler
)
{
this
.
log
.
info
(
`TrialDispatcher: GPU scheduler is enabled.`
)
}
}
validateCodeDir
(
config
.
trialCodeDirectory
);
private
async
asyncConstructor
(
config
:
ExperimentConfig
):
Promise
<
void
>
{
await
validateCodeDir
(
config
.
trialCodeDirectory
);
if
(
Array
.
isArray
(
config
.
trainingService
))
{
config
.
trainingService
.
forEach
(
trainingService
=>
{
const
env
=
EnvironmentServiceFactory
.
createEnvironmentService
(
trainingService
.
platform
,
config
);
this
.
environmentServiceList
.
push
(
env
);
});
}
else
{
const
env
=
EnvironmentServiceFactory
.
createEnvironmentService
(
config
.
trainingService
.
platform
,
config
);
this
.
environmentServiceList
.
push
(
env
);
}
const
serviceConfigs
=
Array
.
isArray
(
config
.
trainingService
)
?
config
.
trainingService
:
[
config
.
trainingService
];
const
servicePromises
=
serviceConfigs
.
map
(
serviceConfig
=>
createEnvironmentService
(
serviceConfig
.
platform
,
config
));
this
.
environmentServiceList
=
await
Promise
.
all
(
servicePromises
);
this
.
environmentMaintenceLoopInterval
=
Math
.
max
(
...
this
.
environmentServiceList
.
map
((
env
)
=>
env
.
environmentMaintenceLoopInterval
)
...
...
@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService {
}
if
(
this
.
config
.
sharedStorage
!==
undefined
)
{
this
.
initializeSharedStorage
(
this
.
config
.
sharedStorage
);
await
this
.
initializeSharedStorage
(
this
.
config
.
sharedStorage
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment