Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5ab984a4
"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b07309d9c63919f2fd601c596fda08114a9da0a6"
Unverified
Commit
5ab984a4
authored
Apr 09, 2021
by
J-shang
Committed by
GitHub
Apr 09, 2021
Browse files
tensorboard backend (#3454)
parent
6808708d
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
500 additions
and
5 deletions
+500
-5
ts/nni_manager/common/manager.ts
ts/nni_manager/common/manager.ts
+3
-0
ts/nni_manager/common/tensorboardManager.ts
ts/nni_manager/common/tensorboardManager.ts
+33
-0
ts/nni_manager/common/trainingService.ts
ts/nni_manager/common/trainingService.ts
+2
-0
ts/nni_manager/common/utils.ts
ts/nni_manager/common/utils.ts
+40
-4
ts/nni_manager/core/nniTensorboardManager.ts
ts/nni_manager/core/nniTensorboardManager.ts
+229
-0
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+10
-0
ts/nni_manager/core/test/mockedTrainingService.ts
ts/nni_manager/core/test/mockedTrainingService.ts
+8
-0
ts/nni_manager/core/test/nnimanager.test.ts
ts/nni_manager/core/test/nnimanager.test.ts
+3
-0
ts/nni_manager/main.ts
ts/nni_manager/main.ts
+5
-0
ts/nni_manager/rest_server/restHandler.ts
ts/nni_manager/rest_server/restHandler.ts
+70
-0
ts/nni_manager/rest_server/test/mockedNNIManager.ts
ts/nni_manager/rest_server/test/mockedNNIManager.ts
+8
-0
ts/nni_manager/rest_server/test/restserver.test.ts
ts/nni_manager/rest_server/test/restserver.test.ts
+4
-1
ts/nni_manager/training_service/dlts/dltsTrainingService.ts
ts/nni_manager/training_service/dlts/dltsTrainingService.ts
+8
-0
ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
.../training_service/kubernetes/kubernetesTrainingService.ts
+8
-0
ts/nni_manager/training_service/local/localTrainingService.ts
...ni_manager/training_service/local/localTrainingService.ts
+16
-0
ts/nni_manager/training_service/pai/paiTrainingService.ts
ts/nni_manager/training_service/pai/paiTrainingService.ts
+8
-0
ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
...ng_service/remote_machine/remoteMachineTrainingService.ts
+8
-0
ts/nni_manager/training_service/reusable/routerTrainingService.ts
...anager/training_service/reusable/routerTrainingService.ts
+14
-0
ts/nni_manager/training_service/reusable/trialDispatcher.ts
ts/nni_manager/training_service/reusable/trialDispatcher.ts
+23
-0
No files found.
ts/nni_manager/common/manager.ts
View file @
5ab984a4
...
@@ -108,6 +108,9 @@ abstract class Manager {
...
@@ -108,6 +108,9 @@ abstract class Manager {
public
abstract
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
;
public
abstract
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
;
public
abstract
getStatus
():
NNIManagerStatus
;
public
abstract
getStatus
():
NNIManagerStatus
;
public
abstract
getTrialOutputLocalPath
(
trialJobId
:
string
):
Promise
<
string
>
;
public
abstract
fetchTrialOutput
(
trialJobId
:
string
,
subpath
:
string
):
Promise
<
void
>
;
}
}
export
{
Manager
,
ExperimentParams
,
ExperimentProfile
,
TrialJobStatistics
,
ProfileUpdateType
,
NNIManagerStatus
,
ExperimentStatus
,
ExperimentStartUpMode
};
export
{
Manager
,
ExperimentParams
,
ExperimentProfile
,
TrialJobStatistics
,
ProfileUpdateType
,
NNIManagerStatus
,
ExperimentStatus
,
ExperimentStartUpMode
};
ts/nni_manager/common/tensorboardManager.ts
0 → 100644
View file @
5ab984a4
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
/**
 * Parameters accepted when starting a tensorboard task.
 */
interface TensorboardParams {
    // Trial job ids as a single string.
    // NOTE(review): the consumer splits this on ',' — confirm against callers.
    trials: string;
}
/**
 * Every status a tensorboard task can be in.
 */
type TensorboardTaskStatus =
    | 'RUNNING'
    | 'DOWNLOADING_DATA'
    | 'STOPPING'
    | 'STOPPED'
    | 'ERROR'
    | 'FAIL_DOWNLOAD_DATA';
/**
 * Read-only description of a tensorboard task.
 */
interface TensorboardTaskInfo {
    // Identifier of the task.
    readonly id: string;
    // Current status of the task.
    readonly status: TensorboardTaskStatus;
    // Ids of the trial jobs covered by this task.
    readonly trialJobIdList: string[];
    // Log directories of those trials.
    // NOTE(review): presumably index-aligned with trialJobIdList — confirm.
    readonly trialLogDirectoryList: string[];
    // Process id, when known.
    readonly pid?: number;
    // Port, kept as a string, when known.
    readonly port?: string;
}
/**
 * Abstract contract for managing tensorboard tasks.
 * Every operation is asynchronous.
 */
abstract class TensorboardManager {
    // Start a new task from the given parameters and resolve with its info.
    public abstract startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskInfo>;
    // Resolve with the info of the task with the given id.
    public abstract getTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
    // Update the task with the given id and resolve with its info.
    public abstract updateTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
    // Resolve with the info of all tasks.
    public abstract listTensorboardTasks(): Promise<TensorboardTaskInfo[]>;
    // Stop the task with the given id and resolve with its info.
    public abstract stopTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
    // Stop all tasks.
    public abstract stopAllTensorboardTask(): Promise<void>;
    // Stop the manager itself.
    public abstract stop(): Promise<void>;
}
export
{
TensorboardParams
,
TensorboardTaskStatus
,
TensorboardTaskInfo
,
TensorboardManager
}
ts/nni_manager/common/trainingService.ts
View file @
5ab984a4
...
@@ -85,6 +85,8 @@ abstract class TrainingService {
...
@@ -85,6 +85,8 @@ abstract class TrainingService {
public
abstract
getTrialLog
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
;
public
abstract
getTrialLog
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
;
public
abstract
setClusterMetadata
(
key
:
string
,
value
:
string
):
Promise
<
void
>
;
public
abstract
setClusterMetadata
(
key
:
string
,
value
:
string
):
Promise
<
void
>
;
public
abstract
getClusterMetadata
(
key
:
string
):
Promise
<
string
>
;
public
abstract
getClusterMetadata
(
key
:
string
):
Promise
<
string
>
;
public
abstract
getTrialOutputLocalPath
(
trialJobId
:
string
):
Promise
<
string
>
;
public
abstract
fetchTrialOutput
(
trialJobId
:
string
,
subpath
:
string
):
Promise
<
void
>
;
public
abstract
cleanUp
():
Promise
<
void
>
;
public
abstract
cleanUp
():
Promise
<
void
>
;
public
abstract
run
():
Promise
<
void
>
;
public
abstract
run
():
Promise
<
void
>
;
}
}
...
...
ts/nni_manager/common/utils.ts
View file @
5ab984a4
...
@@ -9,6 +9,7 @@ import * as cpp from 'child-process-promise';
...
@@ -9,6 +9,7 @@ import * as cpp from 'child-process-promise';
import
*
as
cp
from
'
child_process
'
;
import
*
as
cp
from
'
child_process
'
;
import
{
ChildProcess
,
spawn
,
StdioOptions
}
from
'
child_process
'
;
import
{
ChildProcess
,
spawn
,
StdioOptions
}
from
'
child_process
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
net
from
'
net
'
;
import
*
as
os
from
'
os
'
;
import
*
as
os
from
'
os
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
*
as
lockfile
from
'
lockfile
'
;
import
*
as
lockfile
from
'
lockfile
'
;
...
@@ -340,11 +341,9 @@ async function getVersion(): Promise<string> {
...
@@ -340,11 +341,9 @@ async function getVersion(): Promise<string> {
/**
/**
* run command as ChildProcess
* run command as ChildProcess
*/
*/
function
getTunerProc
(
command
:
string
,
stdio
:
StdioOptions
,
newCwd
:
string
,
newEnv
:
any
):
ChildProcess
{
function
getTunerProc
(
command
:
string
,
stdio
:
StdioOptions
,
newCwd
:
string
,
newEnv
:
any
,
newShell
:
boolean
=
true
,
isDetached
:
boolean
=
false
):
ChildProcess
{
let
cmd
:
string
=
command
;
let
cmd
:
string
=
command
;
let
arg
:
string
[]
=
[];
let
arg
:
string
[]
=
[];
let
newShell
:
boolean
=
true
;
let
isDetached
:
boolean
=
false
;
if
(
process
.
platform
===
"
win32
"
)
{
if
(
process
.
platform
===
"
win32
"
)
{
cmd
=
command
.
split
(
"
"
,
1
)[
0
];
cmd
=
command
.
split
(
"
"
,
1
)[
0
];
arg
=
command
.
substr
(
cmd
.
length
+
1
).
split
(
"
"
);
arg
=
command
.
substr
(
cmd
.
length
+
1
).
split
(
"
"
);
...
@@ -449,8 +448,45 @@ function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]
...
@@ -449,8 +448,45 @@ function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]
return
result
;
return
result
;
}
}
/**
 * Probe whether something accepts TCP connections on `host:port`.
 *
 * Resolves `true` once a connection is established, and `false` when the
 * socket reports an error or no connection happens within one second.
 * Rejects only if creating the connection throws synchronously.
 */
async function isPortOpen(host: string, port: number): Promise<boolean> {
    return new Promise<boolean>((resolve, reject) => {
        try {
            const socket = net.createConnection(port, host);

            // Single exit path: stop the timer, drop the socket, report the verdict.
            const settle = (open: boolean): void => {
                clearTimeout(timer);
                socket.destroy();
                resolve(open);
            };

            const timer = setTimeout(() => settle(false), 1000);
            socket.on('connect', () => settle(true));
            socket.on('error', () => settle(false));
        } catch (error) {
            reject(error);
        }
    });
}
/**
 * Find the first port in `[start, end]` on `host` that nothing is listening on.
 *
 * Ports are probed one by one in ascending order with `isPortOpen`.
 *
 * @throws Error('no more free port') when the range is exhausted.
 */
async function getFreePort(host: string, start: number, end: number): Promise<number> {
    let candidate = start;
    // `!(candidate > end)` rather than `candidate <= end` so that non-comparable
    // inputs are handled exactly as the `start > end` guard always did.
    while (!(candidate > end)) {
        const occupied = await isPortOpen(host, candidate);
        if (!occupied) {
            return candidate;
        }
        candidate += 1;
    }
    throw new Error(`no more free port`);
}
export
{
export
{
countFilesRecursively
,
validateFileNameRecursively
,
generateParamFileName
,
getMsgDispatcherCommand
,
getCheckpointDir
,
getExperimentsInfoPath
,
countFilesRecursively
,
validateFileNameRecursively
,
generateParamFileName
,
getMsgDispatcherCommand
,
getCheckpointDir
,
getExperimentsInfoPath
,
getLogDir
,
getExperimentRootDir
,
getJobCancelStatus
,
getDefaultDatabaseDir
,
getIPV4Address
,
unixPathJoin
,
withLockSync
,
getLogDir
,
getExperimentRootDir
,
getJobCancelStatus
,
getDefaultDatabaseDir
,
getIPV4Address
,
unixPathJoin
,
withLockSync
,
getFreePort
,
isPortOpen
,
mkDirP
,
mkDirPSync
,
delay
,
prepareUnitTest
,
parseArg
,
cleanupUnitTest
,
uniqueString
,
randomInt
,
randomSelect
,
getLogLevel
,
getVersion
,
getCmdPy
,
getTunerProc
,
isAlive
,
killPid
,
getNewLine
mkDirP
,
mkDirPSync
,
delay
,
prepareUnitTest
,
parseArg
,
cleanupUnitTest
,
uniqueString
,
randomInt
,
randomSelect
,
getLogLevel
,
getVersion
,
getCmdPy
,
getTunerProc
,
isAlive
,
killPid
,
getNewLine
};
};
ts/nni_manager/core/nniTensorboardManager.ts
0 → 100644
View file @
5ab984a4
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'
use strict
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
cp
from
'
child_process
'
;
import
*
as
path
from
'
path
'
;
import
{
ChildProcess
}
from
'
child_process
'
;
import
*
as
component
from
'
../common/component
'
;
import
{
getLogger
,
Logger
}
from
'
../common/log
'
;
import
{
getTunerProc
,
isAlive
,
uniqueString
,
mkDirPSync
,
getFreePort
}
from
'
../common/utils
'
;
import
{
Manager
}
from
'
../common/manager
'
;
import
{
TensorboardParams
,
TensorboardTaskStatus
,
TensorboardTaskInfo
,
TensorboardManager
}
from
'
../common/tensorboardManager
'
;
/**
 * Mutable record of one tensorboard task.
 *
 * `pid` and `port` are left unset by the constructor and are meant to be
 * filled in by whoever launches the process.
 */
class TensorboardTaskDetail implements TensorboardTaskInfo {
    public pid?: number;
    public port?: string;

    constructor(
        public id: string,
        public status: TensorboardTaskStatus,
        public trialJobIdList: string[],
        public trialLogDirectoryList: string[]
    ) {}
}
/**
 * TensorboardManager implementation that launches one `tensorboard` process
 * per task on the local machine and tracks every task in an in-memory map.
 */
class NNITensorboardManager implements TensorboardManager {
    private log: Logger;
    private tensorboardTaskMap: Map<string, TensorboardTaskDetail>;
    private tensorboardVersion: string | undefined;
    private nniManager: Manager;

    constructor() {
        this.log = getLogger();
        this.tensorboardTaskMap = new Map<string, TensorboardTaskDetail>();
        this.setTensorboardVersion();
        this.nniManager = component.get(Manager);
    }

    /**
     * Start a tensorboard task for the comma-separated trial ids in
     * `tensorboardParams.trials`.
     *
     * For every trial a `tensorboard` sub-directory is created under the
     * trial's local output path and used as its log directory.
     */
    public async startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskDetail> {
        const trialJobIds = tensorboardParams.trials;
        const trialJobIdList: string[] = [];
        const trialLogDirectoryList: string[] = [];
        await Promise.all(trialJobIds.split(',').map(async (trialJobId) => {
            const trialTensorboardDataPath = path.join(await this.nniManager.getTrialOutputLocalPath(trialJobId), 'tensorboard');
            mkDirPSync(trialTensorboardDataPath);
            // Both pushes happen in the same synchronous step, so the two lists stay index-aligned.
            trialJobIdList.push(trialJobId);
            trialLogDirectoryList.push(trialTensorboardDataPath);
        }));
        this.log.info(`tensorboard: ${trialJobIdList} ${trialLogDirectoryList}`);
        return await this.startTensorboardTaskProcess(trialJobIdList, trialLogDirectoryList);
    }

    /**
     * Pick a free port (6006..65535 on localhost), spawn the tensorboard
     * process, register the task and trigger a first data update.
     */
    private async startTensorboardTaskProcess(trialJobIdList: string[], trialLogDirectoryList: string[]): Promise<TensorboardTaskDetail> {
        const host = 'localhost';
        const port = await getFreePort(host, 6006, 65535);
        const command = await this.getTensorboardStartCommand(trialJobIdList, trialLogDirectoryList, port);
        this.log.info(`tensorboard start command: ${command}`);
        const tensorboardTask = new TensorboardTaskDetail(uniqueString(5), 'RUNNING', trialJobIdList, trialLogDirectoryList);
        this.tensorboardTaskMap.set(tensorboardTask.id, tensorboardTask);

        // Last two arguments: newShell = true, isDetached = true.
        const tensorboardProc: ChildProcess = getTunerProc(command, 'ignore', process.cwd(), process.env, true, true);
        tensorboardProc.on('error', async (error) => {
            this.log.error(error);
            try {
                const alive: boolean = await isAlive(tensorboardProc.pid);
                if (alive) {
                    // Negative pid: signal the whole process group.
                    process.kill(-tensorboardProc.pid);
                }
            } catch (killError) {
                // A failed kill must not prevent the task from being flagged below.
                this.log.error(killError);
            }
            this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
        });
        tensorboardTask.pid = tensorboardProc.pid;
        tensorboardTask.port = `${port}`;
        this.log.info(`tensorboard task id: ${tensorboardTask.id}`);
        // Fire-and-forget on purpose; log a rejection instead of leaving it unhandled.
        this.updateTensorboardTask(tensorboardTask.id).catch((error: Error) => {
            this.log.error(`${error.message}`);
        });
        return tensorboardTask;
    }

    /**
     * Build the shell command that launches tensorboard for the given trials.
     *
     * @throws Error when tensorboard's version cannot be determined, when the
     *         two lists differ in length, or when they are empty.
     */
    private async getTensorboardStartCommand(trialJobIdList: string[], trialLogDirectoryList: string[], port: number): Promise<string> {
        if (this.tensorboardVersion === undefined) {
            this.setTensorboardVersion();
            if (this.tensorboardVersion === undefined) {
                throw new Error(`Tensorboard may not installed, if you want to use tensorboard, please check if tensorboard installed.`);
            }
        }
        if (trialJobIdList.length !== trialLogDirectoryList.length) {
            throw new Error('trial list length does not match');
        }
        if (trialJobIdList.length === 0) {
            throw new Error('trial list length is 0');
        }
        let logdirCmd = '--logdir';
        // Compare the major version numerically; a string comparison would rank '10.0' below '2.0'.
        const majorVersion = Number.parseInt(this.tensorboardVersion.split('.')[0], 10);
        if (majorVersion >= 2) {
            logdirCmd = '--bind_all --logdir_spec';
        }
        try {
            const logRealPaths: string[] = [];
            for (let idx = 0; idx < trialJobIdList.length; idx++) {
                const realPath = fs.realpathSync(trialLogDirectoryList[idx]);
                const trialJob = await this.nniManager.getTrialJob(trialJobIdList[idx]);
                logRealPaths.push(`${trialJob.sequenceId}-${trialJobIdList[idx]}:${realPath}`);
            }
            const command = `tensorboard ${logdirCmd}=${logRealPaths.join(',')} --port=${port}`;
            return command;
        } catch (error) {
            throw new Error(`${error instanceof Error ? error.message : error}`);
        }
    }

    /**
     * Ask the Python interpreter for the installed tensorboard version and
     * remember it. Leaves `tensorboardVersion` untouched (and logs a warning)
     * when the command fails.
     */
    private setTensorboardVersion(): void {
        // NOTE(review): single-quoted arguments are not treated as quotes by cmd.exe;
        // confirm the win32 command actually works.
        let command = `python3 -c 'import tensorboard ; print(tensorboard.__version__)'`;
        if (process.platform === 'win32') {
            command = `python -c 'import tensorboard ; print(tensorboard.__version__)'`;
        }
        try {
            const output = cp.execSync(command).toString();
            // Keep only the dotted version number: drops the trailing newline and any surrounding text.
            const versionMatch = /\d+(\.\d+)*/.exec(output);
            if (versionMatch !== null) {
                this.tensorboardVersion = versionMatch[0];
            }
        } catch (error) {
            this.log.warning(`Tensorboard may not installed, if you want to use tensorboard, please check if tensorboard installed.`);
        }
    }

    /**
     * Look up a task by id. For a task that is not STOPPED, its process is
     * checked and the status is switched to ERROR when it is no longer alive.
     *
     * @throws Error('Tensorboard task not found') for an unknown id.
     */
    public async getTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskDetail> {
        const tensorboardTask: TensorboardTaskDetail | undefined = this.tensorboardTaskMap.get(tensorboardTaskId);
        if (tensorboardTask === undefined) {
            throw new Error('Tensorboard task not found');
        }
        if (tensorboardTask.status !== 'STOPPED') {
            const alive: boolean = await isAlive(tensorboardTask.pid);
            if (!alive) {
                this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
            }
        }
        return tensorboardTask;
    }

    /** Return every task currently held in the map. */
    public async listTensorboardTasks(): Promise<TensorboardTaskDetail[]> {
        return Array.from(this.tensorboardTaskMap.values());
    }

    /** Change a task's status and log the transition; no-op when unchanged. */
    private setTensorboardTaskStatus(tensorboardTask: TensorboardTaskDetail, newStatus: TensorboardTaskStatus): void {
        if (tensorboardTask.status !== newStatus) {
            const oldStatus = tensorboardTask.status;
            tensorboardTask.status = newStatus;
            this.log.info(`tensorboardTask ${tensorboardTask.id} status update: ${oldStatus} to ${tensorboardTask.status}`);
        }
    }

    /** Mark a finished download by moving the task back to RUNNING. */
    private downloadDataFinished(tensorboardTask: TensorboardTaskDetail): void {
        // Only leave DOWNLOADING_DATA; a task that was stopped or failed meanwhile keeps that status.
        if (tensorboardTask.status === 'DOWNLOADING_DATA') {
            this.setTensorboardTaskStatus(tensorboardTask, 'RUNNING');
        }
    }

    /**
     * Start fetching the `tensorboard` output of every trial of the task.
     *
     * The task is switched to DOWNLOADING_DATA and returned immediately; the
     * fetch continues in the background and ends in RUNNING on success or
     * FAIL_DOWNLOAD_DATA on failure.
     *
     * @throws Error when the task is not RUNNING or FAIL_DOWNLOAD_DATA.
     */
    public async updateTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
        const tensorboardTask: TensorboardTaskDetail = await this.getTensorboardTask(tensorboardTaskId);
        if (['RUNNING', 'FAIL_DOWNLOAD_DATA'].includes(tensorboardTask.status)) {
            this.setTensorboardTaskStatus(tensorboardTask, 'DOWNLOADING_DATA');
            // The fetch promises must be returned from the mapper, otherwise Promise.all
            // resolves at once and a failed fetch is never observed.
            Promise.all(tensorboardTask.trialJobIdList.map((trialJobId) => {
                return this.nniManager.fetchTrialOutput(trialJobId, 'tensorboard');
            })).then(() => {
                this.downloadDataFinished(tensorboardTask);
            }).catch((error: Error) => {
                if (tensorboardTask.status === 'DOWNLOADING_DATA') {
                    this.setTensorboardTaskStatus(tensorboardTask, 'FAIL_DOWNLOAD_DATA');
                }
                this.log.error(`${error.message}`);
            });
            return tensorboardTask;
        } else {
            throw new Error('only tensorboard task with RUNNING or FAIL_DOWNLOAD_DATA can update data');
        }
    }

    /**
     * Request termination of a task's process and return the task right away.
     *
     * @throws Error when the task is not RUNNING or FAIL_DOWNLOAD_DATA.
     */
    public async stopTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
        const tensorboardTask = await this.getTensorboardTask(tensorboardTaskId);
        if (['RUNNING', 'FAIL_DOWNLOAD_DATA'].includes(tensorboardTask.status)) {
            // Not awaited, so the caller gets the task back immediately; failures are logged.
            this.killTensorboardTaskProc(tensorboardTask).catch((error: Error) => {
                this.log.error(`${error.message}`);
            });
            return tensorboardTask;
        } else {
            throw new Error('Only RUNNING FAIL_DOWNLOAD_DATA task can be stopped');
        }
    }

    /**
     * Kill the process group of a task and drop the task from the map.
     * Tasks already in ERROR or STOPPED are ignored; a task whose process is
     * already gone is flagged ERROR and kept in the map.
     */
    private async killTensorboardTaskProc(tensorboardTask: TensorboardTaskDetail): Promise<void> {
        if (['ERROR', 'STOPPED'].includes(tensorboardTask.status)) {
            return;
        }
        const alive: boolean = await isAlive(tensorboardTask.pid);
        if (!alive) {
            this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
        } else {
            this.setTensorboardTaskStatus(tensorboardTask, 'STOPPING');
            if (tensorboardTask.pid) {
                // Negative pid: signal the whole process group.
                process.kill(-tensorboardTask.pid);
            }
            this.log.debug(`Tensorboard task ${tensorboardTask.id} stopped.`);
            this.setTensorboardTaskStatus(tensorboardTask, 'STOPPED');
            this.tensorboardTaskMap.delete(tensorboardTask.id);
        }
    }

    /** Kill every task in the map, one after another. */
    public async stopAllTensorboardTask(): Promise<void> {
        this.log.info('Forced stopping all tensorboard task.');
        for (const task of this.tensorboardTaskMap.values()) {
            await this.killTensorboardTaskProc(task);
        }
        this.log.info('All tensorboard task stopped.');
    }

    /** Stop all tasks and log that the manager has stopped. */
    public async stop(): Promise<void> {
        await this.stopAllTensorboardTask();
        this.log.info('Tensorboard manager stopped.');
    }
}
export
{
NNITensorboardManager
,
TensorboardTaskDetail
};
ts/nni_manager/core/nnimanager.ts
View file @
5ab984a4
...
@@ -16,6 +16,7 @@ import {
...
@@ -16,6 +16,7 @@ import {
NNIManagerStatus
,
ProfileUpdateType
,
TrialJobStatistics
NNIManagerStatus
,
ProfileUpdateType
,
TrialJobStatistics
}
from
'
../common/manager
'
;
}
from
'
../common/manager
'
;
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
TensorboardManager
}
from
'
../common/tensorboardManager
'
;
import
{
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
,
LogType
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
,
LogType
}
from
'
../common/trainingService
'
;
}
from
'
../common/trainingService
'
;
...
@@ -356,6 +357,7 @@ class NNIManager implements Manager {
...
@@ -356,6 +357,7 @@ class NNIManager implements Manager {
let
hasError
:
boolean
=
false
;
let
hasError
:
boolean
=
false
;
try
{
try
{
await
this
.
experimentManager
.
stop
();
await
this
.
experimentManager
.
stop
();
await
component
.
get
<
TensorboardManager
>
(
TensorboardManager
).
stop
();
await
this
.
dataStore
.
close
();
await
this
.
dataStore
.
close
();
await
component
.
get
<
NNIRestServer
>
(
NNIRestServer
).
stop
();
await
component
.
get
<
NNIRestServer
>
(
NNIRestServer
).
stop
();
}
catch
(
err
)
{
}
catch
(
err
)
{
...
@@ -881,6 +883,14 @@ class NNIManager implements Manager {
...
@@ -881,6 +883,14 @@ class NNIManager implements Manager {
return
Promise
.
resolve
(
chkpDir
);
return
Promise
.
resolve
(
chkpDir
);
}
}
/**
 * Resolve the local output path of a trial by delegating to the training
 * service's getTrialOutputLocalPath.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    return this.trainingService.getTrialOutputLocalPath(trialJobId);
}
/**
 * Fetch a trial's output for the given subpath by delegating to the
 * training service's fetchTrialOutput.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
    return this.trainingService.fetchTrialOutput(trialJobId, subpath);
}
}
}
export
{
NNIManager
};
export
{
NNIManager
};
ts/nni_manager/core/test/mockedTrainingService.ts
View file @
5ab984a4
...
@@ -124,6 +124,14 @@ class MockedTrainingService extends TrainingService {
...
@@ -124,6 +124,14 @@ class MockedTrainingService extends TrainingService {
public
cleanUp
():
Promise
<
void
>
{
public
cleanUp
():
Promise
<
void
>
{
return
Promise
.
resolve
();
return
Promise
.
resolve
();
}
}
// Not supported by this mock: always throws MethodNotImplementedError
// (synchronously, since the method is not async).
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
// Not supported by this mock: always throws MethodNotImplementedError
// (synchronously, since the method is not async).
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
}
}
export
{
MockedTrainingService
,
testTrainingServiceProvider
}
export
{
MockedTrainingService
,
testTrainingServiceProvider
}
ts/nni_manager/core/test/nnimanager.test.ts
View file @
5ab984a4
...
@@ -19,6 +19,8 @@ import { NNIManager } from '../nnimanager';
...
@@ -19,6 +19,8 @@ import { NNIManager } from '../nnimanager';
import
{
SqlDB
}
from
'
../sqlDatabase
'
;
import
{
SqlDB
}
from
'
../sqlDatabase
'
;
import
{
MockedTrainingService
}
from
'
./mockedTrainingService
'
;
import
{
MockedTrainingService
}
from
'
./mockedTrainingService
'
;
import
{
MockedDataStore
}
from
'
./mockedDatastore
'
;
import
{
MockedDataStore
}
from
'
./mockedDatastore
'
;
import
{
TensorboardManager
}
from
'
../../common/tensorboardManager
'
;
import
{
NNITensorboardManager
}
from
'
../../core/nniTensorboardManager
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
async
function
initContainer
():
Promise
<
void
>
{
async
function
initContainer
():
Promise
<
void
>
{
...
@@ -28,6 +30,7 @@ async function initContainer(): Promise<void> {
...
@@ -28,6 +30,7 @@ async function initContainer(): Promise<void> {
Container
.
bind
(
Database
).
to
(
SqlDB
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
Database
).
to
(
SqlDB
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
DataStore
).
to
(
MockedDataStore
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
DataStore
).
to
(
MockedDataStore
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
ExperimentManager
).
to
(
NNIExperimentsManager
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
ExperimentManager
).
to
(
NNIExperimentsManager
).
scope
(
Scope
.
Singleton
);
Container
.
bind
(
TensorboardManager
).
to
(
NNITensorboardManager
).
scope
(
Scope
.
Singleton
);
await
component
.
get
<
DataStore
>
(
DataStore
).
init
();
await
component
.
get
<
DataStore
>
(
DataStore
).
init
();
}
}
...
...
ts/nni_manager/main.ts
View file @
5ab984a4
...
@@ -13,12 +13,14 @@ import { setExperimentStartupInfo } from './common/experimentStartupInfo';
...
@@ -13,12 +13,14 @@ import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import
{
getLogger
,
Logger
,
logLevelNameMap
}
from
'
./common/log
'
;
import
{
getLogger
,
Logger
,
logLevelNameMap
}
from
'
./common/log
'
;
import
{
Manager
,
ExperimentStartUpMode
}
from
'
./common/manager
'
;
import
{
Manager
,
ExperimentStartUpMode
}
from
'
./common/manager
'
;
import
{
ExperimentManager
}
from
'
./common/experimentManager
'
;
import
{
ExperimentManager
}
from
'
./common/experimentManager
'
;
import
{
TensorboardManager
}
from
'
./common/tensorboardManager
'
;
import
{
TrainingService
}
from
'
./common/trainingService
'
;
import
{
TrainingService
}
from
'
./common/trainingService
'
;
import
{
getLogDir
,
mkDirP
,
parseArg
}
from
'
./common/utils
'
;
import
{
getLogDir
,
mkDirP
,
parseArg
}
from
'
./common/utils
'
;
import
{
NNIDataStore
}
from
'
./core/nniDataStore
'
;
import
{
NNIDataStore
}
from
'
./core/nniDataStore
'
;
import
{
NNIManager
}
from
'
./core/nnimanager
'
;
import
{
NNIManager
}
from
'
./core/nnimanager
'
;
import
{
SqlDB
}
from
'
./core/sqlDatabase
'
;
import
{
SqlDB
}
from
'
./core/sqlDatabase
'
;
import
{
NNIExperimentsManager
}
from
'
./core/nniExperimentsManager
'
;
import
{
NNIExperimentsManager
}
from
'
./core/nniExperimentsManager
'
;
import
{
NNITensorboardManager
}
from
'
./core/nniTensorboardManager
'
;
import
{
NNIRestServer
}
from
'
./rest_server/nniRestServer
'
;
import
{
NNIRestServer
}
from
'
./rest_server/nniRestServer
'
;
import
{
FrameworkControllerTrainingService
}
from
'
./training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService
'
;
import
{
FrameworkControllerTrainingService
}
from
'
./training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService
'
;
import
{
AdlTrainingService
}
from
'
./training_service/kubernetes/adl/adlTrainingService
'
;
import
{
AdlTrainingService
}
from
'
./training_service/kubernetes/adl/adlTrainingService
'
;
...
@@ -76,6 +78,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
...
@@ -76,6 +78,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container
.
bind
(
ExperimentManager
)
Container
.
bind
(
ExperimentManager
)
.
to
(
NNIExperimentsManager
)
.
to
(
NNIExperimentsManager
)
.
scope
(
Scope
.
Singleton
);
.
scope
(
Scope
.
Singleton
);
Container
.
bind
(
TensorboardManager
)
.
to
(
NNITensorboardManager
)
.
scope
(
Scope
.
Singleton
);
const
DEFAULT_LOGFILE
:
string
=
path
.
join
(
getLogDir
(),
'
nnimanager.log
'
);
const
DEFAULT_LOGFILE
:
string
=
path
.
join
(
getLogDir
(),
'
nnimanager.log
'
);
if
(
foreground
)
{
if
(
foreground
)
{
logFileName
=
undefined
;
logFileName
=
undefined
;
...
...
ts/nni_manager/rest_server/restHandler.ts
View file @
5ab984a4
...
@@ -13,6 +13,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
...
@@ -13,6 +13,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import
{
getLogger
,
Logger
}
from
'
../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../common/log
'
;
import
{
ExperimentProfile
,
Manager
,
TrialJobStatistics
}
from
'
../common/manager
'
;
import
{
ExperimentProfile
,
Manager
,
TrialJobStatistics
}
from
'
../common/manager
'
;
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
TensorboardManager
,
TensorboardTaskInfo
}
from
'
../common/tensorboardManager
'
;
import
{
ValidationSchemas
}
from
'
./restValidationSchemas
'
;
import
{
ValidationSchemas
}
from
'
./restValidationSchemas
'
;
import
{
NNIRestServer
}
from
'
./nniRestServer
'
;
import
{
NNIRestServer
}
from
'
./nniRestServer
'
;
import
{
getVersion
}
from
'
../common/utils
'
;
import
{
getVersion
}
from
'
../common/utils
'
;
...
@@ -23,11 +24,13 @@ class NNIRestHandler {
...
@@ -23,11 +24,13 @@ class NNIRestHandler {
private
restServer
:
NNIRestServer
;
private
restServer
:
NNIRestServer
;
private
nniManager
:
Manager
;
private
nniManager
:
Manager
;
private
experimentsManager
:
ExperimentManager
;
private
experimentsManager
:
ExperimentManager
;
private
tensorboardManager
:
TensorboardManager
;
private
log
:
Logger
;
private
log
:
Logger
;
constructor
(
rs
:
NNIRestServer
)
{
constructor
(
rs
:
NNIRestServer
)
{
this
.
nniManager
=
component
.
get
(
Manager
);
this
.
nniManager
=
component
.
get
(
Manager
);
this
.
experimentsManager
=
component
.
get
(
ExperimentManager
);
this
.
experimentsManager
=
component
.
get
(
ExperimentManager
);
this
.
tensorboardManager
=
component
.
get
(
TensorboardManager
);
this
.
restServer
=
rs
;
this
.
restServer
=
rs
;
this
.
log
=
getLogger
();
this
.
log
=
getLogger
();
}
}
...
@@ -64,6 +67,12 @@ class NNIRestHandler {
...
@@ -64,6 +67,12 @@ class NNIRestHandler {
this
.
getTrialLog
(
router
);
this
.
getTrialLog
(
router
);
this
.
exportData
(
router
);
this
.
exportData
(
router
);
this
.
getExperimentsInfo
(
router
);
this
.
getExperimentsInfo
(
router
);
this
.
startTensorboardTask
(
router
);
this
.
getTensorboardTask
(
router
);
this
.
updateTensorboardTask
(
router
);
this
.
stopTensorboardTask
(
router
);
this
.
stopAllTensorboardTask
(
router
);
this
.
listTensorboardTask
(
router
);
this
.
stop
(
router
);
this
.
stop
(
router
);
// Express-joi-validator configuration
// Express-joi-validator configuration
...
@@ -318,6 +327,67 @@ class NNIRestHandler {
...
@@ -318,6 +327,67 @@ class NNIRestHandler {
});
});
}
}
/**
 * Register `POST /tensorboard`: starts a tensorboard task from the request
 * body and replies with a shallow copy of the task info. Failures are passed
 * to handleError with status code 400.
 */
private startTensorboardTask(router: Router): void {
    router.post('/tensorboard', (req: Request, res: Response) => {
        const replyWithTask = (taskDetail: TensorboardTaskInfo): void => {
            this.log.info(taskDetail);
            res.send(Object.assign({}, taskDetail));
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res, false, 400);
        };
        this.tensorboardManager.startTensorboardTask(req.body).then(replyWithTask).catch(replyWithError);
    });
}
/**
 * Register `GET /tensorboard/:id`: replies with a shallow copy of the
 * requested task's info, or hands the failure to handleError.
 */
private getTensorboardTask(router: Router): void {
    router.get('/tensorboard/:id', (req: Request, res: Response) => {
        const replyWithTask = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.getTensorboardTask(req.params.id).then(replyWithTask).catch(replyWithError);
    });
}
/**
 * Register `PUT /tensorboard/:id`: triggers an update of the task and
 * replies with a shallow copy of its info, or hands the failure to handleError.
 */
private updateTensorboardTask(router: Router): void {
    router.put('/tensorboard/:id', (req: Request, res: Response) => {
        const replyWithTask = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.updateTensorboardTask(req.params.id).then(replyWithTask).catch(replyWithError);
    });
}
/**
 * Register `DELETE /tensorboard/:id`: stops the task and replies with a
 * shallow copy of its info, or hands the failure to handleError.
 */
private stopTensorboardTask(router: Router): void {
    router.delete('/tensorboard/:id', (req: Request, res: Response) => {
        const replyWithTask = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.stopTensorboardTask(req.params.id).then(replyWithTask).catch(replyWithError);
    });
}
/**
 * Register `DELETE /tensorboard-tasks`: stops every tensorboard task and
 * replies with an empty body, or hands the failure to handleError.
 */
private stopAllTensorboardTask(router: Router): void {
    router.delete('/tensorboard-tasks', (req: Request, res: Response) => {
        const replyEmpty = (): void => {
            res.send();
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.stopAllTensorboardTask().then(replyEmpty).catch(replyWithError);
    });
}
/**
 * Register `GET /tensorboard-tasks`: replies with the list of all
 * tensorboard tasks, or hands the failure to handleError.
 */
private listTensorboardTask(router: Router): void {
    router.get('/tensorboard-tasks', (req: Request, res: Response) => {
        const replyWithTasks = (taskDetails: TensorboardTaskInfo[]): void => {
            res.send(taskDetails);
        };
        const replyWithError = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.listTensorboardTasks().then(replyWithTasks).catch(replyWithError);
    });
}
private
stop
(
router
:
Router
):
void
{
private
stop
(
router
:
Router
):
void
{
router
.
delete
(
'
/experiment
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
router
.
delete
(
'
/experiment
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
this
.
nniManager
.
stopExperimentTopHalf
().
then
(()
=>
{
this
.
nniManager
.
stopExperimentTopHalf
().
then
(()
=>
{
...
...
ts/nni_manager/rest_server/test/mockedNNIManager.ts
View file @
5ab984a4
...
@@ -189,4 +189,12 @@ export class MockedNNIManager extends Manager {
...
@@ -189,4 +189,12 @@ export class MockedNNIManager extends Manager {
return
Promise
.
resolve
([
job1
,
job2
]);
return
Promise
.
resolve
([
job1
,
job2
]);
}
}
/**
 * Not implemented by the mocked manager.
 *
 * @param _trialJobId - Unused.
 * @returns Never resolves; the returned promise is rejected with
 *          `MethodNotImplementedError`.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    const notImplemented = new MethodNotImplementedError();
    throw notImplemented;
}
/**
 * Not implemented by the mocked manager.
 *
 * @param _trialJobId - Unused.
 * @param _subpath - Unused.
 * @returns Never resolves; the returned promise is rejected with
 *          `MethodNotImplementedError`.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    const notImplemented = new MethodNotImplementedError();
    throw notImplemented;
}
}
}
ts/nni_manager/rest_server/test/restserver.test.ts
View file @
5ab984a4
...
@@ -18,6 +18,8 @@ import { MockedTrainingService } from '../../core/test/mockedTrainingService';
...
@@ -18,6 +18,8 @@ import { MockedTrainingService } from '../../core/test/mockedTrainingService';
import
{
NNIRestServer
}
from
'
../nniRestServer
'
;
import
{
NNIRestServer
}
from
'
../nniRestServer
'
;
import
{
testManagerProvider
}
from
'
./mockedNNIManager
'
;
import
{
testManagerProvider
}
from
'
./mockedNNIManager
'
;
import
{
testExperimentManagerProvider
}
from
'
./mockedExperimentManager
'
;
import
{
testExperimentManagerProvider
}
from
'
./mockedExperimentManager
'
;
import
{
TensorboardManager
}
from
'
../../common/tensorboardManager
'
;
import
{
NNITensorboardManager
}
from
'
../../core/nniTensorboardManager
'
;
describe
(
'
Unit test for rest server
'
,
()
=>
{
describe
(
'
Unit test for rest server
'
,
()
=>
{
...
@@ -28,7 +30,8 @@ describe('Unit test for rest server', () => {
...
@@ -28,7 +30,8 @@ describe('Unit test for rest server', () => {
Container
.
bind
(
Manager
).
provider
(
testManagerProvider
);
Container
.
bind
(
Manager
).
provider
(
testManagerProvider
);
Container
.
bind
(
DataStore
).
to
(
MockedDataStore
);
Container
.
bind
(
DataStore
).
to
(
MockedDataStore
);
Container
.
bind
(
TrainingService
).
to
(
MockedTrainingService
);
Container
.
bind
(
TrainingService
).
to
(
MockedTrainingService
);
Container
.
bind
(
ExperimentManager
).
provider
(
testExperimentManagerProvider
)
Container
.
bind
(
ExperimentManager
).
provider
(
testExperimentManagerProvider
);
Container
.
bind
(
TensorboardManager
).
to
(
NNITensorboardManager
);
const
restServer
:
NNIRestServer
=
component
.
get
(
NNIRestServer
);
const
restServer
:
NNIRestServer
=
component
.
get
(
NNIRestServer
);
restServer
.
start
().
then
(()
=>
{
restServer
.
start
().
then
(()
=>
{
ROOT_URL
=
`
${
restServer
.
endPoint
}
/api/v1/nni`
;
ROOT_URL
=
`
${
restServer
.
endPoint
}
/api/v1/nni`
;
...
...
ts/nni_manager/training_service/dlts/dltsTrainingService.ts
View file @
5ab984a4
...
@@ -565,6 +565,14 @@ class DLTSTrainingService implements TrainingService {
...
@@ -565,6 +565,14 @@ class DLTSTrainingService implements TrainingService {
public
get
isMultiPhaseJobSupported
():
boolean
{
public
get
isMultiPhaseJobSupported
():
boolean
{
return
false
;
return
false
;
}
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @param _subpath - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
}
}
export
{
DLTSTrainingService
};
export
{
DLTSTrainingService
};
ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
View file @
5ab984a4
...
@@ -393,5 +393,13 @@ abstract class KubernetesTrainingService {
...
@@ -393,5 +393,13 @@ abstract class KubernetesTrainingService {
}
}
return
Promise
.
resolve
(
folderUriInAzure
);
return
Promise
.
resolve
(
folderUriInAzure
);
}
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @param _subpath - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
}
}
export
{
KubernetesTrainingService
};
export
{
KubernetesTrainingService
};
ts/nni_manager/training_service/local/localTrainingService.ts
View file @
5ab984a4
...
@@ -583,6 +583,22 @@ class LocalTrainingService implements TrainingService {
...
@@ -583,6 +583,22 @@ class LocalTrainingService implements TrainingService {
const
filepath
:
string
=
path
.
join
(
directory
,
generateParamFileName
(
hyperParameters
));
const
filepath
:
string
=
path
.
join
(
directory
,
generateParamFileName
(
hyperParameters
));
await
fs
.
promises
.
writeFile
(
filepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
await
fs
.
promises
.
writeFile
(
filepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
}
}
/**
 * Builds the local output directory path of a trial.
 *
 * @param trialJobId - Id of the trial job; used as the last path segment.
 * @returns `<rootDir>/trials/<trialJobId>`, joined with `path.join`. The path
 *          is only computed here; its existence is not checked.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    const trialDir = path.join(this.rootDir, 'trials', trialJobId);
    return trialDir;
}
/**
 * Verifies that a trial's output location exists on the local file system.
 *
 * Nothing is copied or read; the method only checks for existence.
 *
 * @param trialJobId - Id of the trial job whose output directory is resolved
 *                     through `getTrialOutputLocalPath`.
 * @param subpath - Optional path appended to the trial output directory.
 *                  When `undefined`, the directory itself is checked.
 * @returns A promise that resolves when the path exists.
 * @throws Error (as a rejection) with message `Trial local path not exist.`
 *         when the path is missing.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string | undefined): Promise<void> {
    let trialLocalPath = await this.getTrialOutputLocalPath(trialJobId);
    // The parameter type admits `undefined` so this guard is meaningful under
    // strict null checks.
    if (subpath !== undefined) {
        trialLocalPath = path.join(trialLocalPath, subpath);
    }
    if (!fs.existsSync(trialLocalPath)) {
        throw new Error('Trial local path not exist.');
    }
}
}
}
export
{
LocalTrainingService
};
export
{
LocalTrainingService
};
ts/nni_manager/training_service/pai/paiTrainingService.ts
View file @
5ab984a4
...
@@ -576,6 +576,14 @@ class PAITrainingService implements TrainingService {
...
@@ -576,6 +576,14 @@ class PAITrainingService implements TrainingService {
const
filepath
:
string
=
path
.
join
(
directory
,
generateParamFileName
(
hyperParameters
));
const
filepath
:
string
=
path
.
join
(
directory
,
generateParamFileName
(
hyperParameters
));
await
fs
.
promises
.
writeFile
(
filepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
await
fs
.
promises
.
writeFile
(
filepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
}
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @param _subpath - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
}
}
export
{
PAITrainingService
};
export
{
PAITrainingService
};
ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
View file @
5ab984a4
...
@@ -679,6 +679,14 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -679,6 +679,14 @@ class RemoteMachineTrainingService implements TrainingService {
await
executor
.
copyFileToRemote
(
localFilepath
,
executor
.
joinPath
(
trialWorkingFolder
,
fileName
));
await
executor
.
copyFileToRemote
(
localFilepath
,
executor
.
joinPath
(
trialWorkingFolder
,
fileName
));
}
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
/**
 * Not implemented by this training service.
 *
 * @param _trialJobId - Unused.
 * @param _subpath - Unused.
 * @returns A promise rejected with `MethodNotImplementedError`.
 */
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    // Reject rather than throw: the declared return type is a promise, and a
    // synchronous exception would bypass a caller's `.catch()` handler.
    return Promise.reject(new MethodNotImplementedError());
}
}
}
export
{
RemoteMachineTrainingService
};
export
{
RemoteMachineTrainingService
};
ts/nni_manager/training_service/reusable/routerTrainingService.ts
View file @
5ab984a4
...
@@ -183,6 +183,20 @@ class RouterTrainingService implements TrainingService {
...
@@ -183,6 +183,20 @@ class RouterTrainingService implements TrainingService {
}
}
return
await
this
.
internalTrainingService
.
run
();
return
await
this
.
internalTrainingService
.
run
();
}
}
/**
 * Forwards the call to the internal training service.
 *
 * @param trialJobId - Id of the trial job, passed through unchanged.
 * @returns Whatever the internal training service's
 *          `getTrialOutputLocalPath` returns.
 * @throws Error (as a rejection) with message
 *         `TrainingService is not assigned!` when no internal training
 *         service is set.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    const delegate = this.internalTrainingService;
    if (delegate === undefined) {
        throw new Error("TrainingService is not assigned!");
    }
    return delegate.getTrialOutputLocalPath(trialJobId);
}
/**
 * Forwards the call to the internal training service.
 *
 * @param trialJobId - Id of the trial job, passed through unchanged.
 * @param subpath - Sub-path inside the trial output, passed through unchanged.
 * @returns Whatever the internal training service's `fetchTrialOutput`
 *          returns.
 * @throws Error (as a rejection) with message
 *         `TrainingService is not assigned!` when no internal training
 *         service is set.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
    const delegate = this.internalTrainingService;
    if (delegate === undefined) {
        throw new Error("TrainingService is not assigned!");
    }
    return delegate.fetchTrialOutput(trialJobId, subpath);
}
}
}
export
{
RouterTrainingService
};
export
{
RouterTrainingService
};
ts/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
5ab984a4
...
@@ -941,6 +941,29 @@ class TrialDispatcher implements TrainingService {
...
@@ -941,6 +941,29 @@ class TrialDispatcher implements TrainingService {
this
.
useSharedStorage
=
true
;
this
.
useSharedStorage
=
true
;
return
Promise
.
resolve
();
return
Promise
.
resolve
();
}
}
/**
 * Builds the local output directory path of a trial.
 *
 * Only the shared-storage case is handled: the path is
 * `<localWorkingRoot>/trials/<trialJobId>`, where `localWorkingRoot` is read
 * from the `SharedStorageService` component.
 *
 * @param trialJobId - Id of the trial job; used as the last path segment.
 * @returns The joined path.
 * @throws Error (as a rejection) with message
 *         `Only support shared storage right now.` when `useSharedStorage`
 *         is falsy.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    // TODO: support non shared storage
    if (!this.useSharedStorage) {
        throw new Error('Only support shared storage right now.');
    }
    const sharedStorage = component.get<SharedStorageService>(SharedStorageService);
    const workingRoot = sharedStorage.localWorkingRoot;
    return path.join(workingRoot, 'trials', trialJobId);
}
/**
 * Verifies that a trial's output location exists on the local file system.
 *
 * Nothing is copied or read; the method only checks for existence.
 *
 * @param trialJobId - Id of the trial job whose output directory is resolved
 *                     through `getTrialOutputLocalPath`.
 * @param subpath - Optional path appended to the trial output directory.
 *                  When `undefined`, the directory itself is checked.
 * @returns A promise that resolves when the path exists.
 * @throws Error (as a rejection) with message `Trial local path not exist.`
 *         when the path is missing; rejections from
 *         `getTrialOutputLocalPath` propagate unchanged.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string | undefined): Promise<void> {
    // TODO: support non shared storage
    const outputRoot = await this.getTrialOutputLocalPath(trialJobId);
    const target = subpath !== undefined ? path.join(outputRoot, subpath) : outputRoot;
    if (!fs.existsSync(target)) {
        throw new Error('Trial local path not exist.');
    }
}
}
}
export
{
TrialDispatcher
};
export
{
TrialDispatcher
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment