OpenDAS / nni - commit 403195f0 (unverified)
Authored Jul 15, 2021 by Yuge Zhang; committed by GitHub on Jul 15, 2021

    Merge branch 'master' into nn-meter

Parents: 99aa8226, a7278d2d
Changes: 207 files in the full merge; this page shows 20 changed files with 330 additions and 98 deletions (+330 / -98).
Files changed on this page:

  test/ut/retiarii/test_highlevel_apis.py (+66 / -1)
  test/ut/sdk/test_compression_utils.py (+1 / -1)
  test/ut/sdk/test_model_speedup.py (+142 / -36)
  test/ut/tools/nnictl/mock/restful_server.py (+1 / -1)
  ts/nni_manager/common/experimentConfig.ts (+1 / -0)
  ts/nni_manager/common/manager.ts (+2 / -2)
  ts/nni_manager/common/trainingService.ts (+2 / -4)
  ts/nni_manager/common/utils.ts (+10 / -2)
  ts/nni_manager/core/nnimanager.ts (+20 / -4)
  ts/nni_manager/core/test/mockedTrainingService.ts (+2 / -2)
  ts/nni_manager/package.json (+2 / -0)
  ts/nni_manager/rest_server/nniRestServer.ts (+11 / -0)
  ts/nni_manager/rest_server/restHandler.ts (+34 / -8)
  ts/nni_manager/rest_server/test/mockedNNIManager.ts (+2 / -2)
  ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts (+3 / -3)
  ts/nni_manager/training_service/local/localTrainingService.ts (+14 / -12)
  ts/nni_manager/training_service/pai/paiTrainingService.ts (+9 / -12)
  ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts (+3 / -3)
  ts/nni_manager/training_service/reusable/routerTrainingService.ts (+2 / -2)
  ts/nni_manager/training_service/reusable/trialDispatcher.ts (+3 / -3)
test/ut/retiarii/test_highlevel_apis.py

@@ -5,7 +5,7 @@ from collections import Counter
 import nni.retiarii.nn.pytorch as nn
 import torch
 import torch.nn.functional as F
-from nni.retiarii import Sampler, basic_unit
+from nni.retiarii import InvalidMutation, Sampler, basic_unit
 from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 from nni.retiarii.execution.python import _unpack_if_only_one

@@ -520,6 +520,45 @@ class GraphIR(unittest.TestCase):
                 model = mutator.bind_sampler(sampler).apply(model)
             self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))

+    def test_nasbench201_cell(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.NasBench201Cell([lambda x, y: nn.Linear(x, y),
+                                                lambda x, y: nn.Linear(x, y, bias=False)], 10, 16)
+
+            def forward(self, x):
+                return self.cell(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
+
+    def test_autoactivation(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.act = nn.AutoActivation()
+
+            def forward(self, x):
+                return self.act(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))
+

 class Python(GraphIR):
     def _get_converted_pytorch_model(self, model_ir):

@@ -545,3 +584,29 @@ class Python(GraphIR):
     @unittest.skip
     def test_valuechoice_access_functional_expression(self):
         ...
+
+    def test_nasbench101_cell(self):
+        # this is only supported in python engine for now.
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.NasBench101Cell([lambda x: nn.Linear(x, x), lambda x: nn.Linear(x, x, bias=False)],
+                                               10, 16, lambda x, y: nn.Linear(x, y), max_num_nodes=5, max_num_edges=7)
+
+            def forward(self, x):
+                return self.cell(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        succeeded = 0
+        sampler = RandomSampler()
+        while succeeded <= 10:
+            try:
+                model = raw_model
+                for mutator in mutators:
+                    model = mutator.bind_sampler(sampler).apply(model)
+                succeeded += 1
+            except InvalidMutation:
+                continue
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
test/ut/sdk/test_compression_utils.py

@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
         pruner.export_model(ck_file, mask_file)
         pruner._unwrap_model()
         # Fix the mask conflict
-        fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
+        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
         # use the channel dependency groud truth to check if
         # fix the mask conflict successfully
test/ut/sdk/test_model_speedup.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import logging
 import os
+import gc
+import psutil
 import sys
 import numpy as np

@@ -9,18 +11,20 @@ import torch
 import torchvision.models as models
 import torch.nn as nn
 import torch.nn.functional as F
-from torchvision.models.vgg import vgg16
+from torchvision.models.vgg import vgg16, vgg11
 from torchvision.models.resnet import resnet18
 from torchvision.models.mobilenet import mobilenet_v2
 import unittest
 from unittest import TestCase, main

 from nni.compression.pytorch import ModelSpeedup, apply_compression_results
-from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
+from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
 from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
 from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner

 torch.manual_seed(0)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 BATCH_SIZE = 2
 # the relative distance
 RELATIVE_THRESHOLD = 0.01
@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
         return x

+class TupleUnpack_backbone(nn.Module):
+    def __init__(self, width):
+        super(TupleUnpack_backbone, self).__init__()
+        self.model_backbone = mobilenet_v2(pretrained=False, width_mult=width, num_classes=3)
+
+    def forward(self, x):
+        x1 = self.model_backbone.features[:7](x)
+        x2 = self.model_backbone.features[7:14](x1)
+        x3 = self.model_backbone.features[14:18](x2)
+        return [x1, x2, x3]
+
+class TupleUnpack_FPN(nn.Module):
+    def __init__(self):
+        super(TupleUnpack_FPN, self).__init__()
+        self.conv1 = nn.Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        self.conv2 = nn.Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        self.conv3 = nn.Conv2d(320, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        # self.init_weights()
+
+    def forward(self, inputs):
+        """Forward function."""
+        laterals = []
+        laterals.append(self.conv1(inputs[0]))  # inputs[0]==x1
+        laterals.append(self.conv2(inputs[1]))  # inputs[1]==x2
+        laterals.append(self.conv3(inputs[2]))  # inputs[2]==x3
+        return laterals
+
+class TupleUnpack_Model(nn.Module):
+    def __init__(self):
+        super(TupleUnpack_Model, self).__init__()
+        self.backbone = TupleUnpack_backbone(1.0)
+        self.fpn = TupleUnpack_FPN()
+
+    def forward(self, x):
+        x1 = self.backbone(x)
+        out = self.fpn(x1)
+        return out
+
 dummy_input = torch.randn(2, 1, 28, 28)
 SPARSITY = 0.5
 MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
                              'sparsity': sparsity})
     return cfg_list

 def generate_random_sparsity_v2(model):
     """
     Only select 50% layers to prune.

@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
         if np.random.uniform(0, 1.0) > 0.5:
             sparsity = np.random.uniform(0.5, 0.99)
             cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
                              'sparsity': sparsity})
     return cfg_list

 def zero_bn_bias(model):
     with torch.no_grad():
         for name, module in model.named_modules():
@@ -231,19 +286,6 @@ def channel_prune(model):

 class SpeedupTestCase(TestCase):
-    def test_speedup_vgg16(self):
-        prune_model_l1(vgg16())
-        model = vgg16()
-        model.train()
-        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
-        ms.speedup_model()
-
-        orig_model = vgg16()
-        assert model.training
-        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
-        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)
-
     def test_speedup_bigmodel(self):
         prune_model_l1(BigModel())

@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
         mask_out = model(dummy_input)

         model.train()
-        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
         ms.speedup_model()
         assert model.training
@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
         new_model = TransposeModel()
         state_dict = torch.load(MODEL_FILE)
         new_model.load_state_dict(state_dict)
-        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
         ms.speedup_model()
         zero_bn_bias(ori_model)
         zero_bn_bias(new_model)

@@ -297,26 +339,38 @@ class SpeedupTestCase(TestCase):
         new_out = new_model(dummy_input)
         ori_sum = torch.sum(ori_out)
         speeded_sum = torch.sum(new_out)
-        print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(
-            ori_sum, speeded_sum))
+        print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
         assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
-            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+               (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

     # FIXME: This test case might fail randomly, no idea why
     # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
+    def test_speedup_integration_small(self):
+        model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
+        self.speedup_integration(model_list)
+
+    def test_speedup_integration_big(self):
+        model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1', 'densenet121', 'resnet50', 'wide_resnet50_2']
+        mem_info = psutil.virtual_memory()
+        ava_gb = mem_info.available / 1024.0 / 1024 / 1024
+        print('Avaliable memory size: %.2f GB' % ava_gb)
+        if ava_gb < 8.0:
+            # memory size is too small that we may run into an OOM exception
+            # Skip this test in the pipeline test due to memory limitation
+            return
+        self.speedup_integration(model_list)
+
-    def test_speedup_integration(self):
-        # skip this test on windows(7GB mem available) due to memory limit
+    def speedup_integration(self, model_list, speedup_cfg=None):
         # Note: hack trick, may be updated in the future
         if 'win' in sys.platform or 'Win' in sys.platform:
             print('Skip test_speedup_integration on windows due to memory limit!')
             return
         Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

-        for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', 'densenet169',
-                           # 'inception_v3' inception is too large and may fail the pipeline
-                           'resnet50']:
-        # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
-        #                    # 'inception_v3' inception is too large and may fail the pipeline
-        #                    'resnet50']:
+        for model_name in model_list:
             for gen_cfg_func in Gen_cfg_funcs:
                 kwargs = {
                     'pretrained': True
@@ -334,7 +388,10 @@ class SpeedupTestCase(TestCase):
                 speedup_model.eval()
                 # random generate the prune config for the pruner
                 cfgs = gen_cfg_func(net)
                 print("Testing {} with compression config \n {}".format(model_name, cfgs))
+                if len(cfgs) == 0:
+                    continue
                 pruner = L1FilterPruner(net, cfgs)
                 pruner.compress()
                 pruner.export_model(MODEL_FILE, MASK_FILE)

@@ -345,7 +402,10 @@ class SpeedupTestCase(TestCase):
                 zero_bn_bias(speedup_model)

                 data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
-                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
+                if speedup_cfg is None:
+                    speedup_cfg = {}
+                ms = ModelSpeedup(speedup_model, data, MASK_FILE, confidence=2, **speedup_cfg)
                 ms.speedup_model()

                 speedup_model.eval()

@@ -355,12 +415,13 @@ class SpeedupTestCase(TestCase):
                 ori_sum = torch.sum(ori_out).item()
                 speeded_sum = torch.sum(speeded_out).item()
-                print('Sum of the output of %s (before speedup):' %
-                      model_name, ori_sum)
-                print('Sum of the output of %s (after speedup):' %
-                      model_name, speeded_sum)
+                print('Sum of the output of %s (before speedup):' % model_name, ori_sum)
+                print('Sum of the output of %s (after speedup):' % model_name, speeded_sum)
                 assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
-                    (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+                       (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+                print("Collecting Garbage")
+                gc.collect(2)

     def test_channel_prune(self):
         orig_net = resnet18(num_classes=10).to(device)
@@ -378,7 +439,7 @@ class SpeedupTestCase(TestCase):
         net.eval()
         data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
-        ms = ModelSpeedup(net, data, MASK_FILE)
+        ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
         ms.speedup_model()
         ms.bound_model(data)

@@ -391,11 +452,56 @@ class SpeedupTestCase(TestCase):
         assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

+    def test_speedup_tupleunpack(self):
+        """This test is reported in issue3645"""
+        model = TupleUnpack_Model()
+        cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
+        dummy_input = torch.rand(2, 3, 224, 224)
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        model(dummy_input)
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
+        ms.speedup_model()
+
+    def test_finegrained_speedup(self):
+        """ Test the speedup on the fine-grained sparsity"""
+        class MLP(nn.Module):
+            def __init__(self):
+                super(MLP, self).__init__()
+                self.fc1 = nn.Linear(1024, 1024)
+                self.fc2 = nn.Linear(1024, 1024)
+                self.fc3 = nn.Linear(1024, 512)
+                self.fc4 = nn.Linear(512, 10)
+
+            def forward(self, x):
+                x = x.view(-1, 1024)
+                x = self.fc1(x)
+                x = self.fc2(x)
+                x = self.fc3(x)
+                x = self.fc4(x)
+                return x
+
+        model = MLP().to(device)
+        dummy_input = torch.rand(16, 1, 32, 32).to(device)
+        cfg_list = [{'op_types': ['Linear'], 'sparsity': 0.99}]
+        pruner = LevelPruner(model, cfg_list)
+        pruner.compress()
+        print('Original Arch')
+        print(model)
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
+        ms.speedup_model()
+        print("Fine-grained speeduped model")
+        print(model)
+
     def tearDown(self):
         if os.path.exists(MODEL_FILE):
             os.remove(MODEL_FILE)
         if os.path.exists(MASK_FILE):
             os.remove(MASK_FILE)
+        # GC to release memory
+        gc.collect(2)

 if __name__ == '__main__':
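The recurring change in the speedup tests above is that ModelSpeedup is now constructed with an explicit confidence keyword (confidence=2 for the convolutional models, confidence=4 in the fine-grained MLP case). The sketch below shows that pruner-export-speedup flow end to end, mirroring the tests; treating confidence as a small internal batch size for sparsity/shape inference is an assumption on my part and is not stated anywhere in this diff.

    import torch
    from torchvision.models import resnet18
    from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
    from nni.compression.pytorch import ModelSpeedup

    # Prune 50% of the filters in every Conv2d layer, export the masks, then
    # rebuild a smaller dense model with ModelSpeedup (same flow as the tests above).
    model = resnet18()
    cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
    pruner = L1FilterPruner(model, cfg_list)
    pruner.compress()
    pruner.export_model('./model.pth', './mask.pth')
    pruner._unwrap_model()  # remove the pruner wrappers before running speedup

    dummy_input = torch.randn(2, 3, 224, 224)
    # confidence=2 mirrors the updated tests; its exact semantics are not described in this diff.
    ms = ModelSpeedup(model, dummy_input, './mask.pth', confidence=2)
    ms.speedup_model()
    print(model)  # the Conv2d layers now carry roughly half of their original channels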
test/ut/tools/nnictl/mock/restful_server.py

@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():
 def mock_get_trial_log():
     responses.add(
-        responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type',
+        responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename',
         json={"status": "RUNNING", "errors":[]}, status=200, content_type='application/json',
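This mock tracks the REST API rename from /trial-log/:id/:type to /trial-file/:id/:filename (the server-side handler appears in restHandler.ts further down). A hypothetical client call against the renamed endpoint could look like the sketch below; the host, port, and trial id are placeholders, and only the route and the text/binary behaviour are taken from this commit.

    import requests

    trial_id = 'abc123'  # placeholder trial id, not a value from this commit
    base = 'http://localhost:8080/api/v1/nni'

    # 'stdout' has no extension, so the handler serves it as utf-8 text
    print(requests.get(f'{base}/trial-file/{trial_id}/stdout').text)

    # 'model.onnx' is returned as application/octet-stream, so read the raw bytes
    onnx_bytes = requests.get(f'{base}/trial-file/{trial_id}/model.onnx').content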
ts/nni_manager/common/experimentConfig.ts

@@ -161,6 +161,7 @@ export interface ExperimentConfig {
     trialConcurrency: number;
     trialGpuNumber?: number;
     maxExperimentDuration?: string;
+    maxTrialDuration?: string;
     maxTrialNumber?: number;
     nniManagerIp?: string;
     //useAnnotation: boolean;  // dealed inside nnictl
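The new optional maxTrialDuration field pairs with the nnimanager.ts changes further down, where each submitted trial gets a timer that cancels it once toSeconds(maxTrialDuration) elapses. A hedged sketch of how an experiment might opt into it is below; the snake_case Python mirror (max_trial_duration) and the exact duration-string format are assumptions and are not shown on this page of the diff.

    from nni.experiment import Experiment

    experiment = Experiment('local')
    experiment.config.trial_command = 'python3 trial.py'      # placeholder trial command
    experiment.config.trial_code_directory = '.'
    experiment.config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-4, 1e-1]}}
    experiment.config.tuner.name = 'TPE'
    experiment.config.trial_concurrency = 1
    experiment.config.max_trial_number = 10
    # Assumed Python-side counterpart of maxTrialDuration; '10m' follows the same
    # duration-string style as maxExperimentDuration.
    experiment.config.max_trial_duration = '10m'
    experiment.run(8080)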
ts/nni_manager/common/manager.ts

@@ -4,7 +4,7 @@
 'use strict';

 import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
-import { TrialJobStatus, LogType } from './trainingService';
+import { TrialJobStatus } from './trainingService';
 import { ExperimentConfig } from './experimentConfig';

 type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';

@@ -59,7 +59,7 @@ abstract class Manager {
     public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
     public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
-    public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
+    public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
     public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
     public abstract getStatus(): NNIManagerStatus;
ts/nni_manager/common/trainingService.ts

@@ -8,8 +8,6 @@
  */

 type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';

-type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';
-
 interface TrainingServiceMetadata {
     readonly key: string;
     readonly value: string;

@@ -81,7 +79,7 @@ abstract class TrainingService {
     public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
     public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
     public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
-    public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
+    public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
     public abstract setClusterMetadata(key: string, value: string): Promise<void>;
     public abstract getClusterMetadata(key: string): Promise<string>;
     public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;

@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
 export {
     TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
-    TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, NNIManagerIpConfig, LogType
+    TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, NNIManagerIpConfig
 };
ts/nni_manager/common/utils.ts

@@ -223,7 +223,7 @@ let cachedIpv4Address: string | null = null;
 /**
  * Get IPv4 address of current machine.
  */
-function getIPV4Address(): string {
+async function getIPV4Address(): Promise<string> {
     if (cachedIpv4Address !== null) {
         return cachedIpv4Address;
     }

@@ -232,12 +232,20 @@ function getIPV4Address(): string {
     // since udp is connectionless, this does not send actual packets.
     const socket = dgram.createSocket('udp4');
     socket.connect(1, '192.0.2.0');
-    cachedIpv4Address = socket.address().address;
+    for (let i = 0; i < 10; i++) {  // wait the system to initialize "connection"
+        await yield_();
+        try {
+            cachedIpv4Address = socket.address().address;
+        } catch (error) {
+            /* retry */
+        }
+    }
+    cachedIpv4Address = socket.address().address;  // if it still fails, throw the error
     socket.close();
     return cachedIpv4Address;
 }

+async function yield_(): Promise<void> {
+    /* trigger the scheduler, do nothing */
+}
+
 /**
  * Get the status of canceled jobs according to the hint isEarlyStopped
  */
ts/nni_manager/core/nnimanager.ts

@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
 import { ExperimentManager } from '../common/experimentManager';
 import { TensorboardManager } from '../common/tensorboardManager';
 import {
-    TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
+    TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
 } from '../common/trainingService';
 import {
     delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid
 } from '../common/utils';
 import {

@@ -189,7 +189,6 @@ class NNIManager implements Manager {
         this.log.debug(`dispatcher command: ${dispatcherCommand}`);
         const checkpointDir: string = await this.createCheckpointDir();
         this.setupTuner(dispatcherCommand, undefined, 'start', checkpointDir);
-
         this.setStatus('RUNNING');
         await this.storeExperimentProfile();
         this.run().catch((err: Error) => {

@@ -403,8 +402,8 @@ class NNIManager implements Manager {
         // FIXME: unit test
     }

-    public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
-        return this.trainingService.getTrialLog(trialJobId, logType);
+    public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
+        return this.trainingService.getTrialFile(trialJobId, fileName);
     }

     public getExperimentProfile(): Promise<ExperimentProfile> {

@@ -433,6 +432,11 @@ class NNIManager implements Manager {
         return (value === undefined ? Infinity : value);
     }

+    private get maxTrialDuration(): number {
+        const value = this.experimentProfile.params.maxTrialDuration;
+        return (value === undefined ? Infinity : toSeconds(value));
+    }
+
     private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
         let platform: string;
         if (Array.isArray(config.trainingService)) {

@@ -539,6 +543,17 @@ class NNIManager implements Manager {
         }
     }

+    private async stopTrialJobIfOverMaxDurationTimer(trialJobId: string): Promise<void> {
+        const trialJobDetail: TrialJobDetail | undefined = this.trialJobs.get(trialJobId);
+        if (undefined !== trialJobDetail && trialJobDetail.status === 'RUNNING' &&
+            trialJobDetail.startTime !== undefined){
+            const isEarlyStopped = true;
+            await this.trainingService.cancelTrialJob(trialJobId, isEarlyStopped);
+            this.log.info(`Trial job ${trialJobId} has stoped because it is over maxTrialDuration.`);
+        }
+    }
+
     private async requestTrialJobsStatus(): Promise<number> {
         let finishedTrialJobNum: number = 0;
         if (this.dispatcher === undefined) {

@@ -662,6 +677,7 @@ class NNIManager implements Manager {
         this.currSubmittedTrialNum++;
         this.log.info('submitTrialJob: form:', form);
         const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(form);
+        setTimeout(async () => this.stopTrialJobIfOverMaxDurationTimer(trialJobDetail.id), 1000 * this.maxTrialDuration);
         const Snapshot: TrialJobDetail = Object.assign({}, trialJobDetail);
         await this.storeExperimentProfile();
         this.trialJobs.set(trialJobDetail.id, Snapshot);
ts/nni_manager/core/test/mockedTrainingService.ts

@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
 import { Provider } from 'typescript-ioc';

 import { MethodNotImplementedError } from '../../common/errors';
-import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';

 const testTrainingServiceProvider: Provider = {
     get: () => { return new MockedTrainingService(); }

@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
         return deferred.promise;
     }

-    public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
+    public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
         throw new MethodNotImplementedError();
     }
ts/nni_manager/package.json

@@ -15,6 +15,7 @@
     "child-process-promise": "^2.2.1",
     "express": "^4.17.1",
     "express-joi-validator": "^2.0.1",
+    "http-proxy": "^1.18.1",
     "ignore": "^5.1.8",
     "js-base64": "^3.6.1",
     "kubernetes-client": "^6.12.1",

@@ -37,6 +38,7 @@
     "@types/chai-as-promised": "^7.1.0",
     "@types/express": "^4.17.2",
     "@types/glob": "^7.1.3",
+    "@types/http-proxy": "^1.17.7",
     "@types/js-base64": "^3.3.1",
     "@types/js-yaml": "^4.0.1",
     "@types/lockfile": "^1.0.0",
ts/nni_manager/rest_server/nniRestServer.ts

@@ -5,6 +5,7 @@
 import * as bodyParser from 'body-parser';
 import * as express from 'express';
+import * as httpProxy from 'http-proxy';
 import * as path from 'path';
 import * as component from '../common/component';
 import { RestServer } from '../common/restServer'

@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
 @component.Singleton
 export class NNIRestServer extends RestServer {
     private readonly LOGS_ROOT_URL: string = '/logs';
+    protected netronProxy: any = null;
     protected API_ROOT_URL: string = '/api/v1/nni';

     /**

@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
     constructor() {
         super();
         this.API_ROOT_URL = getAPIRootUrl();
+        this.netronProxy = httpProxy.createProxyServer();
     }

     /**

@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
         this.app.use(bodyParser.json({limit: '50mb'}));
         this.app.use(this.API_ROOT_URL, createRestHandler(this));
         this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
+        this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
+            delete req.headers.host;
+            req.url = req.url.replace('/netron', '/');
+            this.netronProxy.web(req, res, {
+                changeOrigin: true,
+                target: 'https://netron.app'
+            });
+        });
         this.app.get('*', (req: express.Request, res: express.Response) => {
             res.sendFile(path.resolve('static/index.html'));
         });
ts/nni_manager/rest_server/restHandler.ts

@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
 import { getVersion } from '../common/utils';
 import { MetricType } from '../common/datastore';
 import { ProfileUpdateType } from '../common/manager';
-import { LogType, TrialJobStatus } from '../common/trainingService';
+import { TrialJobStatus } from '../common/trainingService';

 const expressJoi = require('express-joi-validator');

@@ -53,6 +53,7 @@ class NNIRestHandler {
         this.version(router);
         this.checkStatus(router);
         this.getExperimentProfile(router);
+        this.getExperimentMetadata(router);
         this.updateExperimentProfile(router);
         this.importData(router);
         this.getImportedData(router);

@@ -66,7 +67,7 @@ class NNIRestHandler {
         this.getMetricData(router);
         this.getMetricDataByRange(router);
         this.getLatestMetricData(router);
-        this.getTrialLog(router);
+        this.getTrialFile(router);
         this.exportData(router);
         this.getExperimentsInfo(router);
         this.startTensorboardTask(router);

@@ -296,13 +297,20 @@ class NNIRestHandler {
         });
     }

-    private getTrialLog(router: Router): void {
-        router.get('/trial-log/:id/:type', async(req: Request, res: Response) => {
-            this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => {
-                if (log === '') {
-                    log = 'No logs available.'
-                }
-                res.send(log);
+    private getTrialFile(router: Router): void {
+        router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
+            let encoding: string | null = null;
+            const filename = req.params.filename;
+            if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
+                encoding = 'utf8';
+            }
+            this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => {
+                if (content instanceof Buffer) {
+                    res.header('Content-Type', 'application/octet-stream');
+                } else if (content === '') {
+                    content = `${filename} is empty.`;
+                }
+                res.send(content);
             }).catch((err: Error) => {
                 this.handleError(err, res);
             });

@@ -319,6 +327,24 @@ class NNIRestHandler {
         });
     }

+    private getExperimentMetadata(router: Router): void {
+        router.get('/experiment-metadata', (req: Request, res: Response) => {
+            Promise.all([
+                this.nniManager.getExperimentProfile(),
+                this.experimentsManager.getExperimentsInfo()
+            ]).then(([profile, experimentInfo]) => {
+                for (const info of experimentInfo as any) {
+                    if (info.id === profile.id) {
+                        res.send(info);
+                        break;
+                    }
+                }
+            }).catch((err: Error) => {
+                this.handleError(err, res);
+            });
+        });
+    }
+
     private getExperimentsInfo(router: Router): void {
         router.get('/experiments-info', (req: Request, res: Response) => {
             this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
ts/nni_manager/rest_server/test/mockedNNIManager.ts

@@ -13,7 +13,7 @@ import {
     TrialJobStatistics, NNIManagerStatus
 } from '../../common/manager';
 import {
-    TrialJobApplicationForm, TrialJobDetail, TrialJobStatus, LogType
+    TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
 } from '../../common/trainingService';

 export const testManagerProvider: Provider = {

@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
     public getLatestMetricData(): Promise<MetricDataRecord[]> {
         throw new MethodNotImplementedError();
     }
-    public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
+    public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
         throw new MethodNotImplementedError();
     }
     public getExperimentProfile(): Promise<ExperimentProfile> {
ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts

@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import { MethodNotImplementedError } from '../../common/errors';
 import {
-    NNIManagerIpConfig, TrialJobDetail, TrialJobMetric, LogType
+    NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import { delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString } from '../../common/utils';
 import { AzureStorageClientUtility } from './azureStorageClientUtils';

@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
         return Promise.resolve(kubernetesTrialJob);
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _filename: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }

@@ -277,7 +277,7 @@ abstract class KubernetesTrainingService {
         if (gpuNum === 0) {
             nvidiaScript = 'export CUDA_VISIBLE_DEVICES=';
         }
-        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
+        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : await getIPV4Address();
         const version: string = this.versionCheck ? await getVersion() : '';
         const runScript: string = String.Format(kubernetesScriptFormat,
ts/nni_manager/training_service/local/localTrainingService.ts

@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import {
-    HyperParameters, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
+    HyperParameters, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
 } from '../../common/trainingService';
 import {
     delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString

@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
         return trialJob;
     }

-    public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
-        let logPath: string;
-        if (logType === 'TRIAL_LOG') {
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'trial.log');
-        } else if (logType === 'TRIAL_STDOUT'){
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'stdout');
-        } else if (logType === 'TRIAL_ERROR') {
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'stderr');
-        } else {
-            throw new Error('unexpected log type');
+    public async getTrialFile(trialJobId: string, fileName: string): Promise<string | Buffer> {
+        // check filename here for security
+        if (!['trial.log', 'stderr', 'model.onnx', 'stdout'].includes(fileName)) {
+            throw new Error(`File unaccessible: ${fileName}`);
+        }
+        let encoding: string | null = null;
+        if (!fileName.includes('.') || fileName.match(/.*\.(txt|log)/g)) {
+            encoding = 'utf8';
         }
-        return fs.promises.readFile(logPath, 'utf8');
+        const logPath = path.join(this.rootDir, 'trials', trialJobId, fileName);
+        if (!fs.existsSync(logPath)) {
+            throw new Error(`File not found: ${logPath}`);
+        }
+        return fs.promises.readFile(logPath, {encoding: encoding as any});
     }

     public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
ts/nni_manager/training_service/pai/paiTrainingService.ts

@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
 import { MethodNotImplementedError } from '../../common/errors';
 import {
-    HyperParameters, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
+    HyperParameters, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import { delay } from '../../common/utils';
 import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';

@@ -23,10 +23,7 @@ import { PAIJobInfoCollector } from './paiJobInfoCollector';
 import { PAIJobRestServer } from './paiJobRestServer';
 import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
 import { String } from 'typescript-string-operations';
-import { generateParamFileName, getIPV4Address, uniqueString } from '../../common/utils';
-
-
+import { generateParamFileName, getIPV4Address, uniqueString } from '../../common/utils';
 import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
 import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

@@ -127,7 +124,7 @@ class PAITrainingService implements TrainingService {
         return jobs;
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }

@@ -332,7 +329,7 @@ class PAITrainingService implements TrainingService {
         return trialJobDetail;
     }

-    private generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): string {
+    private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
         const containerNFSExpCodeDir = `${this.config.containerStorageMountPoint}/${this.experimentId}/nni-code`;
         const containerWorkingDir: string = `${this.config.containerStorageMountPoint}/${this.experimentId}/${trialJobDetail.id}`;
         const nniPaiTrialCommand: string = String.Format(

@@ -345,7 +342,7 @@ class PAITrainingService implements TrainingService {
             false,  // multi-phase
             containerNFSExpCodeDir,
             command,
-            this.config.nniManagerIp || getIPV4Address(),
+            this.config.nniManagerIp || await getIPV4Address(),
             this.paiRestServerPort,
             this.nniVersion,
             this.logCollection

@@ -356,7 +353,7 @@ class PAITrainingService implements TrainingService {
     }

-    private generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): any {
+    private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
         const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`
         let nniJobConfig: any = undefined;

@@ -367,7 +364,7 @@ class PAITrainingService implements TrainingService {
             // Each command will be formatted to NNI style
             for (const taskRoleIndex in nniJobConfig.taskRoles) {
                 const commands = nniJobConfig.taskRoles[taskRoleIndex].commands
-                const nniTrialCommand = this.generateNNITrialCommand(trialJobDetail, commands.join(" && ").replace(/(["'$`\\])/g, '\\$1'));
+                const nniTrialCommand = await this.generateNNITrialCommand(trialJobDetail, commands.join(" && ").replace(/(["'$`\\])/g, '\\$1'));
                 nniJobConfig.taskRoles[taskRoleIndex].commands = [nniTrialCommand]
             }

@@ -399,7 +396,7 @@ class PAITrainingService implements TrainingService {
                         memoryMB: toMegaBytes(this.config.trialMemorySize)
                     },
                     commands: [
-                        this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
+                        await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
                     ]
                 }
             },

@@ -456,7 +453,7 @@ class PAITrainingService implements TrainingService {
         }

         //Generate Job Configuration in yaml format
-        const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobDetail);
+        const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
         this.log.debug(paiJobConfig);
         // Step 2. Submit PAI job via Rest call
         // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts

@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
 import { ObservableTimer } from '../../common/observableTimer';
 import {
-    HyperParameters, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
+    HyperParameters, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import {
     delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,

@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
      * @param _trialJobId ID of trial job
      * @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
      */
-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }

@@ -491,7 +491,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 cudaVisible = `CUDA_VISIBLE_DEVICES=" "`;
             }
         }
-        const nniManagerIp: string = this.config.nniManagerIp ? this.config.nniManagerIp : getIPV4Address();
+        const nniManagerIp: string = this.config.nniManagerIp ? this.config.nniManagerIp : await getIPV4Address();
         if (this.remoteRestServerPort === undefined) {
             const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
             this.remoteRestServerPort = restServer.clusterRestServerPort;
ts/nni_manager/training_service/reusable/routerTrainingService.ts

@@ -6,7 +6,7 @@
 import { getLogger, Logger } from '../../common/log';
 import { MethodNotImplementedError } from '../../common/errors';
 import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
-import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
 import { delay } from '../../common/utils';
 import { PAITrainingService } from '../pai/paiTrainingService';
 import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';

@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
         return await this.internalTrainingService.getTrialJob(trialJobId);
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
ts/nni_manager/training_service/reusable/trialDispatcher.ts

@@ -13,7 +13,7 @@ import * as component from '../../common/component';
 import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
 import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
-import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService';
 import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
 import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig';
 import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';

@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
         return trial;
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }

@@ -216,7 +216,7 @@ class TrialDispatcher implements TrainingService {
         for (const environmentService of this.environmentServiceList) {

             const runnerSettings: RunnerSettings = new RunnerSettings();
-            runnerSettings.nniManagerIP = this.config.nniManagerIp === undefined ? getIPV4Address() : this.config.nniManagerIp;
+            runnerSettings.nniManagerIP = this.config.nniManagerIp === undefined ? await getIPV4Address() : this.config.nniManagerIp;
             runnerSettings.nniManagerPort = getBasePort() + 1;
             runnerSettings.commandChannel = environmentService.getCommandChannel.channelName;
             runnerSettings.enableGpuCollector = this.enableGpuScheduler;