Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
403195f0
"docs/archive_en_US/NAS/NasGuide.md" did not exist on "bc0f8f338ba8a7e42e29cbbf47a0edca8244cfcd"
Unverified
Commit
403195f0
authored
Jul 15, 2021
by
Yuge Zhang
Committed by
GitHub
Jul 15, 2021
Browse files
Merge branch 'master' into nn-meter
parents
99aa8226
a7278d2d
Changes
207
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
330 additions
and
98 deletions
+330
-98
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+66
-1
test/ut/sdk/test_compression_utils.py
test/ut/sdk/test_compression_utils.py
+1
-1
test/ut/sdk/test_model_speedup.py
test/ut/sdk/test_model_speedup.py
+142
-36
test/ut/tools/nnictl/mock/restful_server.py
test/ut/tools/nnictl/mock/restful_server.py
+1
-1
ts/nni_manager/common/experimentConfig.ts
ts/nni_manager/common/experimentConfig.ts
+1
-0
ts/nni_manager/common/manager.ts
ts/nni_manager/common/manager.ts
+2
-2
ts/nni_manager/common/trainingService.ts
ts/nni_manager/common/trainingService.ts
+2
-4
ts/nni_manager/common/utils.ts
ts/nni_manager/common/utils.ts
+10
-2
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+20
-4
ts/nni_manager/core/test/mockedTrainingService.ts
ts/nni_manager/core/test/mockedTrainingService.ts
+2
-2
ts/nni_manager/package.json
ts/nni_manager/package.json
+2
-0
ts/nni_manager/rest_server/nniRestServer.ts
ts/nni_manager/rest_server/nniRestServer.ts
+11
-0
ts/nni_manager/rest_server/restHandler.ts
ts/nni_manager/rest_server/restHandler.ts
+34
-8
ts/nni_manager/rest_server/test/mockedNNIManager.ts
ts/nni_manager/rest_server/test/mockedNNIManager.ts
+2
-2
ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
.../training_service/kubernetes/kubernetesTrainingService.ts
+3
-3
ts/nni_manager/training_service/local/localTrainingService.ts
...ni_manager/training_service/local/localTrainingService.ts
+14
-12
ts/nni_manager/training_service/pai/paiTrainingService.ts
ts/nni_manager/training_service/pai/paiTrainingService.ts
+9
-12
ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
...ng_service/remote_machine/remoteMachineTrainingService.ts
+3
-3
ts/nni_manager/training_service/reusable/routerTrainingService.ts
...anager/training_service/reusable/routerTrainingService.ts
+2
-2
ts/nni_manager/training_service/reusable/trialDispatcher.ts
ts/nni_manager/training_service/reusable/trialDispatcher.ts
+3
-3
No files found.
test/ut/retiarii/test_highlevel_apis.py
View file @
403195f0
...
@@ -5,7 +5,7 @@ from collections import Counter
...
@@ -5,7 +5,7 @@ from collections import Counter
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
basic_unit
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.execution.python
import
_unpack_if_only_one
from
nni.retiarii.execution.python
import
_unpack_if_only_one
...
@@ -520,6 +520,45 @@ class GraphIR(unittest.TestCase):
...
@@ -520,6 +520,45 @@ class GraphIR(unittest.TestCase):
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
def
test_nasbench201_cell
(
self
):
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
cell
=
nn
.
NasBench201Cell
([
lambda
x
,
y
:
nn
.
Linear
(
x
,
y
),
lambda
x
,
y
:
nn
.
Linear
(
x
,
y
,
bias
=
False
)
],
10
,
16
)
def
forward
(
self
,
x
):
return
self
.
cell
(
x
)
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
for
_
in
range
(
10
):
sampler
=
EnumerateSampler
()
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
2
,
10
)).
size
()
==
torch
.
Size
([
2
,
16
]))
def
test_autoactivation
(
self
):
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
act
=
nn
.
AutoActivation
()
def
forward
(
self
,
x
):
return
self
.
act
(
x
)
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
for
_
in
range
(
10
):
sampler
=
EnumerateSampler
()
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
2
,
10
)).
size
()
==
torch
.
Size
([
2
,
10
]))
class
Python
(
GraphIR
):
class
Python
(
GraphIR
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
...
@@ -545,3 +584,29 @@ class Python(GraphIR):
...
@@ -545,3 +584,29 @@ class Python(GraphIR):
@
unittest
.
skip
@
unittest
.
skip
def
test_valuechoice_access_functional_expression
(
self
):
...
def
test_valuechoice_access_functional_expression
(
self
):
...
def
test_nasbench101_cell
(
self
):
# this is only supported in python engine for now.
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
cell
=
nn
.
NasBench101Cell
([
lambda
x
:
nn
.
Linear
(
x
,
x
),
lambda
x
:
nn
.
Linear
(
x
,
x
,
bias
=
False
)],
10
,
16
,
lambda
x
,
y
:
nn
.
Linear
(
x
,
y
),
max_num_nodes
=
5
,
max_num_edges
=
7
)
def
forward
(
self
,
x
):
return
self
.
cell
(
x
)
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
succeeded
=
0
sampler
=
RandomSampler
()
while
succeeded
<=
10
:
try
:
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
succeeded
+=
1
except
InvalidMutation
:
continue
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
2
,
10
)).
size
()
==
torch
.
Size
([
2
,
16
]))
test/ut/sdk/test_compression_utils.py
View file @
403195f0
...
@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
...
@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
pruner
.
export_model
(
ck_file
,
mask_file
)
pruner
.
export_model
(
ck_file
,
mask_file
)
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
# Fix the mask conflict
# Fix the mask conflict
fixed_mask
,
_
=
fix_mask_conflict
(
mask_file
,
net
,
dummy_input
)
fixed_mask
=
fix_mask_conflict
(
mask_file
,
net
,
dummy_input
)
# use the channel dependency groud truth to check if
# use the channel dependency groud truth to check if
# fix the mask conflict successfully
# fix the mask conflict successfully
...
...
test/ut/sdk/test_model_speedup.py
View file @
403195f0
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
os
import
os
import
gc
import
psutil
import
psutil
import
sys
import
sys
import
numpy
as
np
import
numpy
as
np
...
@@ -9,18 +11,20 @@ import torch
...
@@ -9,18 +11,20 @@ import torch
import
torchvision.models
as
models
import
torchvision.models
as
models
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torchvision.models.vgg
import
vgg16
from
torchvision.models.vgg
import
vgg16
,
vgg11
from
torchvision.models.resnet
import
resnet18
from
torchvision.models.resnet
import
resnet18
from
torchvision.models.mobilenet
import
mobilenet_v2
import
unittest
import
unittest
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
from
nni.compression.pytorch
import
ModelSpeedup
,
apply_compression_results
from
nni.compression.pytorch
import
ModelSpeedup
,
apply_compression_results
from
nni.algorithms.compression.pytorch.pruning
import
L1FilterPruner
from
nni.algorithms.compression.pytorch.pruning
import
L1FilterPruner
,
LevelPruner
from
nni.algorithms.compression.pytorch.pruning.weight_masker
import
WeightMasker
from
nni.algorithms.compression.pytorch.pruning.weight_masker
import
WeightMasker
from
nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner
import
DependencyAwarePruner
from
nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner
import
DependencyAwarePruner
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
BATCH_SIZE
=
2
BATCH_SIZE
=
2
# the relative distance
# the relative distance
RELATIVE_THRESHOLD
=
0.01
RELATIVE_THRESHOLD
=
0.01
...
@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
...
@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
return
x
return
x
class
TupleUnpack_backbone
(
nn
.
Module
):
def
__init__
(
self
,
width
):
super
(
TupleUnpack_backbone
,
self
).
__init__
()
self
.
model_backbone
=
mobilenet_v2
(
pretrained
=
False
,
width_mult
=
width
,
num_classes
=
3
)
def
forward
(
self
,
x
):
x1
=
self
.
model_backbone
.
features
[:
7
](
x
)
x2
=
self
.
model_backbone
.
features
[
7
:
14
](
x1
)
x3
=
self
.
model_backbone
.
features
[
14
:
18
](
x2
)
return
[
x1
,
x2
,
x3
]
class
TupleUnpack_FPN
(
nn
.
Module
):
def
__init__
(
self
):
super
(
TupleUnpack_FPN
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
32
,
48
,
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
bias
=
False
)
self
.
conv2
=
nn
.
Conv2d
(
96
,
48
,
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
bias
=
False
)
self
.
conv3
=
nn
.
Conv2d
(
320
,
48
,
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
bias
=
False
)
# self.init_weights()
def
forward
(
self
,
inputs
):
"""Forward function."""
laterals
=
[]
laterals
.
append
(
self
.
conv1
(
inputs
[
0
]))
# inputs[0]==x1
laterals
.
append
(
self
.
conv2
(
inputs
[
1
]))
# inputs[1]==x2
laterals
.
append
(
self
.
conv3
(
inputs
[
2
]))
# inputs[2]==x3
return
laterals
class
TupleUnpack_Model
(
nn
.
Module
):
def
__init__
(
self
):
super
(
TupleUnpack_Model
,
self
).
__init__
()
self
.
backbone
=
TupleUnpack_backbone
(
1.0
)
self
.
fpn
=
TupleUnpack_FPN
()
def
forward
(
self
,
x
):
x1
=
self
.
backbone
(
x
)
out
=
self
.
fpn
(
x1
)
return
out
dummy_input
=
torch
.
randn
(
2
,
1
,
28
,
28
)
dummy_input
=
torch
.
randn
(
2
,
1
,
28
,
28
)
SPARSITY
=
0.5
SPARSITY
=
0.5
MODEL_FILE
,
MASK_FILE
=
'./11_model.pth'
,
'./l1_mask.pth'
MODEL_FILE
,
MASK_FILE
=
'./11_model.pth'
,
'./l1_mask.pth'
...
@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
...
@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
'sparsity'
:
sparsity
})
'sparsity'
:
sparsity
})
return
cfg_list
return
cfg_list
def
generate_random_sparsity_v2
(
model
):
def
generate_random_sparsity_v2
(
model
):
"""
"""
Only select 50% layers to prune.
Only select 50% layers to prune.
...
@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
...
@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
if
np
.
random
.
uniform
(
0
,
1.0
)
>
0.5
:
if
np
.
random
.
uniform
(
0
,
1.0
)
>
0.5
:
sparsity
=
np
.
random
.
uniform
(
0.5
,
0.99
)
sparsity
=
np
.
random
.
uniform
(
0.5
,
0.99
)
cfg_list
.
append
({
'op_types'
:
[
'Conv2d'
],
'op_names'
:
[
name
],
cfg_list
.
append
({
'op_types'
:
[
'Conv2d'
],
'op_names'
:
[
name
],
'sparsity'
:
sparsity
})
'sparsity'
:
sparsity
})
return
cfg_list
return
cfg_list
def
zero_bn_bias
(
model
):
def
zero_bn_bias
(
model
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
...
@@ -231,19 +286,6 @@ def channel_prune(model):
...
@@ -231,19 +286,6 @@ def channel_prune(model):
class
SpeedupTestCase
(
TestCase
):
class
SpeedupTestCase
(
TestCase
):
def
test_speedup_vgg16
(
self
):
prune_model_l1
(
vgg16
())
model
=
vgg16
()
model
.
train
()
ms
=
ModelSpeedup
(
model
,
torch
.
randn
(
2
,
3
,
32
,
32
),
MASK_FILE
)
ms
.
speedup_model
()
orig_model
=
vgg16
()
assert
model
.
training
assert
model
.
features
[
2
].
out_channels
==
int
(
orig_model
.
features
[
2
].
out_channels
*
SPARSITY
)
assert
model
.
classifier
[
0
].
in_features
==
int
(
orig_model
.
classifier
[
0
].
in_features
*
SPARSITY
)
def
test_speedup_bigmodel
(
self
):
def
test_speedup_bigmodel
(
self
):
prune_model_l1
(
BigModel
())
prune_model_l1
(
BigModel
())
...
@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
...
@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
mask_out
=
model
(
dummy_input
)
mask_out
=
model
(
dummy_input
)
model
.
train
()
model
.
train
()
ms
=
ModelSpeedup
(
model
,
dummy_input
,
MASK_FILE
)
ms
=
ModelSpeedup
(
model
,
dummy_input
,
MASK_FILE
,
confidence
=
2
)
ms
.
speedup_model
()
ms
.
speedup_model
()
assert
model
.
training
assert
model
.
training
...
@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
...
@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
new_model
=
TransposeModel
()
new_model
=
TransposeModel
()
state_dict
=
torch
.
load
(
MODEL_FILE
)
state_dict
=
torch
.
load
(
MODEL_FILE
)
new_model
.
load_state_dict
(
state_dict
)
new_model
.
load_state_dict
(
state_dict
)
ms
=
ModelSpeedup
(
new_model
,
dummy_input
,
MASK_FILE
)
ms
=
ModelSpeedup
(
new_model
,
dummy_input
,
MASK_FILE
,
confidence
=
2
)
ms
.
speedup_model
()
ms
.
speedup_model
()
zero_bn_bias
(
ori_model
)
zero_bn_bias
(
ori_model
)
zero_bn_bias
(
new_model
)
zero_bn_bias
(
new_model
)
...
@@ -297,26 +339,38 @@ class SpeedupTestCase(TestCase):
...
@@ -297,26 +339,38 @@ class SpeedupTestCase(TestCase):
new_out
=
new_model
(
dummy_input
)
new_out
=
new_model
(
dummy_input
)
ori_sum
=
torch
.
sum
(
ori_out
)
ori_sum
=
torch
.
sum
(
ori_out
)
speeded_sum
=
torch
.
sum
(
new_out
)
speeded_sum
=
torch
.
sum
(
new_out
)
print
(
'Tanspose Speedup Test: ori_sum={} speedup_sum={}'
.
format
(
ori_sum
,
speeded_sum
))
print
(
'Tanspose Speedup Test: ori_sum={} speedup_sum={}'
.
format
(
ori_sum
,
speeded_sum
))
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
# FIXME: This test case might fail randomly, no idea why
def
test_speedup_integration_small
(
self
):
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
model_list
=
[
'resnet18'
,
'mobilenet_v2'
,
'alexnet'
]
self
.
speedup_integration
(
model_list
)
def
test_speedup_integration_big
(
self
):
model_list
=
[
'vgg11'
,
'vgg16'
,
'resnet34'
,
'squeezenet1_1'
,
'densenet121'
,
'resnet50'
,
'wide_resnet50_2'
]
mem_info
=
psutil
.
virtual_memory
()
ava_gb
=
mem_info
.
available
/
1024.0
/
1024
/
1024
print
(
'Avaliable memory size: %.2f GB'
%
ava_gb
)
if
ava_gb
<
8.0
:
# memory size is too small that we may run into an OOM exception
# Skip this test in the pipeline test due to memory limitation
return
self
.
speedup_integration
(
model_list
)
def
test_speedup_integration
(
self
):
def
speedup_integration
(
self
,
model_list
,
speedup_cfg
=
None
):
# skip this test on windows(7GB mem available) due to memory limit
# Note: hack trick, may be updated in the future
# Note: hack trick, may be updated in the future
if
'win'
in
sys
.
platform
or
'Win'
in
sys
.
platform
:
if
'win'
in
sys
.
platform
or
'Win'
in
sys
.
platform
:
print
(
'Skip test_speedup_integration on windows due to memory limit!'
)
print
(
'Skip test_speedup_integration on windows due to memory limit!'
)
return
return
Gen_cfg_funcs
=
[
generate_random_sparsity
,
generate_random_sparsity_v2
]
Gen_cfg_funcs
=
[
generate_random_sparsity
,
generate_random_sparsity_v2
]
for
model_name
in
[
'resnet18'
,
'mobilenet_v2'
,
'squeezenet1_1'
,
'densenet121'
,
'densenet169'
,
#
for model_name in [
'vgg16',
'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121'
,
# 'inception_v3' inception is too large and may fail the pipeline
#
# 'inception_v3' inception is too large and may fail the pipeline
'resnet50'
]:
#
'resnet50']:
for
model_name
in
model_list
:
for
gen_cfg_func
in
Gen_cfg_funcs
:
for
gen_cfg_func
in
Gen_cfg_funcs
:
kwargs
=
{
kwargs
=
{
'pretrained'
:
True
'pretrained'
:
True
...
@@ -334,7 +388,10 @@ class SpeedupTestCase(TestCase):
...
@@ -334,7 +388,10 @@ class SpeedupTestCase(TestCase):
speedup_model
.
eval
()
speedup_model
.
eval
()
# random generate the prune config for the pruner
# random generate the prune config for the pruner
cfgs
=
gen_cfg_func
(
net
)
cfgs
=
gen_cfg_func
(
net
)
print
(
"Testing {} with compression config
\n
{}"
.
format
(
model_name
,
cfgs
))
print
(
"Testing {} with compression config
\n
{}"
.
format
(
model_name
,
cfgs
))
if
len
(
cfgs
)
==
0
:
continue
pruner
=
L1FilterPruner
(
net
,
cfgs
)
pruner
=
L1FilterPruner
(
net
,
cfgs
)
pruner
.
compress
()
pruner
.
compress
()
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
...
@@ -345,7 +402,10 @@ class SpeedupTestCase(TestCase):
...
@@ -345,7 +402,10 @@ class SpeedupTestCase(TestCase):
zero_bn_bias
(
speedup_model
)
zero_bn_bias
(
speedup_model
)
data
=
torch
.
ones
(
BATCH_SIZE
,
3
,
128
,
128
).
to
(
device
)
data
=
torch
.
ones
(
BATCH_SIZE
,
3
,
128
,
128
).
to
(
device
)
ms
=
ModelSpeedup
(
speedup_model
,
data
,
MASK_FILE
)
if
speedup_cfg
is
None
:
speedup_cfg
=
{}
ms
=
ModelSpeedup
(
speedup_model
,
data
,
MASK_FILE
,
confidence
=
2
,
**
speedup_cfg
)
ms
.
speedup_model
()
ms
.
speedup_model
()
speedup_model
.
eval
()
speedup_model
.
eval
()
...
@@ -355,12 +415,13 @@ class SpeedupTestCase(TestCase):
...
@@ -355,12 +415,13 @@ class SpeedupTestCase(TestCase):
ori_sum
=
torch
.
sum
(
ori_out
).
item
()
ori_sum
=
torch
.
sum
(
ori_out
).
item
()
speeded_sum
=
torch
.
sum
(
speeded_out
).
item
()
speeded_sum
=
torch
.
sum
(
speeded_out
).
item
()
print
(
'Sum of the output of %s (before speedup):'
%
print
(
'Sum of the output of %s (before speedup):'
%
model_name
,
ori_sum
)
model_name
,
ori_sum
)
print
(
'Sum of the output of %s (after speedup):'
%
print
(
'Sum of the output of %s (after
speedup):'
%
model_name
,
speeded_sum
)
model_name
,
speeded_sum
)
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
print
(
"Collecting Garbage"
)
gc
.
collect
(
2
)
def
test_channel_prune
(
self
):
def
test_channel_prune
(
self
):
orig_net
=
resnet18
(
num_classes
=
10
).
to
(
device
)
orig_net
=
resnet18
(
num_classes
=
10
).
to
(
device
)
...
@@ -378,7 +439,7 @@ class SpeedupTestCase(TestCase):
...
@@ -378,7 +439,7 @@ class SpeedupTestCase(TestCase):
net
.
eval
()
net
.
eval
()
data
=
torch
.
randn
(
BATCH_SIZE
,
3
,
128
,
128
).
to
(
device
)
data
=
torch
.
randn
(
BATCH_SIZE
,
3
,
128
,
128
).
to
(
device
)
ms
=
ModelSpeedup
(
net
,
data
,
MASK_FILE
)
ms
=
ModelSpeedup
(
net
,
data
,
MASK_FILE
,
confidence
=
2
)
ms
.
speedup_model
()
ms
.
speedup_model
()
ms
.
bound_model
(
data
)
ms
.
bound_model
(
data
)
...
@@ -391,11 +452,56 @@ class SpeedupTestCase(TestCase):
...
@@ -391,11 +452,56 @@ class SpeedupTestCase(TestCase):
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
assert
(
abs
(
ori_sum
-
speeded_sum
)
/
abs
(
ori_sum
)
<
RELATIVE_THRESHOLD
)
or
\
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
(
abs
(
ori_sum
-
speeded_sum
)
<
ABSOLUTE_THRESHOLD
)
def
test_speedup_tupleunpack
(
self
):
"""This test is reported in issue3645"""
model
=
TupleUnpack_Model
()
cfg_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.5
}]
dummy_input
=
torch
.
rand
(
2
,
3
,
224
,
224
)
pruner
=
L1FilterPruner
(
model
,
cfg_list
)
pruner
.
compress
()
model
(
dummy_input
)
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
ms
=
ModelSpeedup
(
model
,
dummy_input
,
MASK_FILE
,
confidence
=
2
)
ms
.
speedup_model
()
def
test_finegrained_speedup
(
self
):
""" Test the speedup on the fine-grained sparsity"""
class
MLP
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MLP
,
self
).
__init__
()
self
.
fc1
=
nn
.
Linear
(
1024
,
1024
)
self
.
fc2
=
nn
.
Linear
(
1024
,
1024
)
self
.
fc3
=
nn
.
Linear
(
1024
,
512
)
self
.
fc4
=
nn
.
Linear
(
512
,
10
)
def
forward
(
self
,
x
):
x
=
x
.
view
(
-
1
,
1024
)
x
=
self
.
fc1
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
fc3
(
x
)
x
=
self
.
fc4
(
x
)
return
x
model
=
MLP
().
to
(
device
)
dummy_input
=
torch
.
rand
(
16
,
1
,
32
,
32
).
to
(
device
)
cfg_list
=
[{
'op_types'
:
[
'Linear'
],
'sparsity'
:
0.99
}]
pruner
=
LevelPruner
(
model
,
cfg_list
)
pruner
.
compress
()
print
(
'Original Arch'
)
print
(
model
)
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
pruner
.
_unwrap_model
()
ms
=
ModelSpeedup
(
model
,
dummy_input
,
MASK_FILE
,
confidence
=
4
)
ms
.
speedup_model
()
print
(
"Fine-grained speeduped model"
)
print
(
model
)
def
tearDown
(
self
):
def
tearDown
(
self
):
if
os
.
path
.
exists
(
MODEL_FILE
):
if
os
.
path
.
exists
(
MODEL_FILE
):
os
.
remove
(
MODEL_FILE
)
os
.
remove
(
MODEL_FILE
)
if
os
.
path
.
exists
(
MASK_FILE
):
if
os
.
path
.
exists
(
MASK_FILE
):
os
.
remove
(
MASK_FILE
)
os
.
remove
(
MASK_FILE
)
# GC to release memory
gc
.
collect
(
2
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
test/ut/tools/nnictl/mock/restful_server.py
View file @
403195f0
...
@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():
...
@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():
def
mock_get_trial_log
():
def
mock_get_trial_log
():
responses
.
add
(
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/trial-
log
/:id/:
typ
e'
,
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/trial-
file
/:id/:
filenam
e'
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
200
,
status
=
200
,
content_type
=
'application/json'
,
content_type
=
'application/json'
,
...
...
ts/nni_manager/common/experimentConfig.ts
View file @
403195f0
...
@@ -161,6 +161,7 @@ export interface ExperimentConfig {
...
@@ -161,6 +161,7 @@ export interface ExperimentConfig {
trialConcurrency
:
number
;
trialConcurrency
:
number
;
trialGpuNumber
?:
number
;
trialGpuNumber
?:
number
;
maxExperimentDuration
?:
string
;
maxExperimentDuration
?:
string
;
maxTrialDuration
?:
string
;
maxTrialNumber
?:
number
;
maxTrialNumber
?:
number
;
nniManagerIp
?:
string
;
nniManagerIp
?:
string
;
//useAnnotation: boolean; // dealed inside nnictl
//useAnnotation: boolean; // dealed inside nnictl
...
...
ts/nni_manager/common/manager.ts
View file @
403195f0
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
'
use strict
'
;
'
use strict
'
;
import
{
MetricDataRecord
,
MetricType
,
TrialJobInfo
}
from
'
./datastore
'
;
import
{
MetricDataRecord
,
MetricType
,
TrialJobInfo
}
from
'
./datastore
'
;
import
{
TrialJobStatus
,
LogType
}
from
'
./trainingService
'
;
import
{
TrialJobStatus
}
from
'
./trainingService
'
;
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
import
{
ExperimentConfig
}
from
'
./experimentConfig
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
type
ProfileUpdateType
=
'
TRIAL_CONCURRENCY
'
|
'
MAX_EXEC_DURATION
'
|
'
SEARCH_SPACE
'
|
'
MAX_TRIAL_NUM
'
;
...
@@ -59,7 +59,7 @@ abstract class Manager {
...
@@ -59,7 +59,7 @@ abstract class Manager {
public
abstract
getMetricDataByRange
(
minSeqId
:
number
,
maxSeqId
:
number
):
Promise
<
MetricDataRecord
[]
>
;
public
abstract
getMetricDataByRange
(
minSeqId
:
number
,
maxSeqId
:
number
):
Promise
<
MetricDataRecord
[]
>
;
public
abstract
getLatestMetricData
():
Promise
<
MetricDataRecord
[]
>
;
public
abstract
getLatestMetricData
():
Promise
<
MetricDataRecord
[]
>
;
public
abstract
getTrial
Log
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
;
public
abstract
getTrial
File
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
Buffer
|
string
>
;
public
abstract
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
;
public
abstract
getTrialJobStatistics
():
Promise
<
TrialJobStatistics
[]
>
;
public
abstract
getStatus
():
NNIManagerStatus
;
public
abstract
getStatus
():
NNIManagerStatus
;
...
...
ts/nni_manager/common/trainingService.ts
View file @
403195f0
...
@@ -8,8 +8,6 @@
...
@@ -8,8 +8,6 @@
*/
*/
type
TrialJobStatus
=
'
UNKNOWN
'
|
'
WAITING
'
|
'
RUNNING
'
|
'
SUCCEEDED
'
|
'
FAILED
'
|
'
USER_CANCELED
'
|
'
SYS_CANCELED
'
|
'
EARLY_STOPPED
'
;
type
TrialJobStatus
=
'
UNKNOWN
'
|
'
WAITING
'
|
'
RUNNING
'
|
'
SUCCEEDED
'
|
'
FAILED
'
|
'
USER_CANCELED
'
|
'
SYS_CANCELED
'
|
'
EARLY_STOPPED
'
;
type
LogType
=
'
TRIAL_LOG
'
|
'
TRIAL_STDOUT
'
|
'
TRIAL_ERROR
'
;
interface
TrainingServiceMetadata
{
interface
TrainingServiceMetadata
{
readonly
key
:
string
;
readonly
key
:
string
;
readonly
value
:
string
;
readonly
value
:
string
;
...
@@ -81,7 +79,7 @@ abstract class TrainingService {
...
@@ -81,7 +79,7 @@ abstract class TrainingService {
public
abstract
submitTrialJob
(
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
;
public
abstract
submitTrialJob
(
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
;
public
abstract
updateTrialJob
(
trialJobId
:
string
,
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
;
public
abstract
updateTrialJob
(
trialJobId
:
string
,
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
;
public
abstract
cancelTrialJob
(
trialJobId
:
string
,
isEarlyStopped
?:
boolean
):
Promise
<
void
>
;
public
abstract
cancelTrialJob
(
trialJobId
:
string
,
isEarlyStopped
?:
boolean
):
Promise
<
void
>
;
public
abstract
getTrial
Log
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
;
public
abstract
getTrial
File
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
Buffer
|
string
>
;
public
abstract
setClusterMetadata
(
key
:
string
,
value
:
string
):
Promise
<
void
>
;
public
abstract
setClusterMetadata
(
key
:
string
,
value
:
string
):
Promise
<
void
>
;
public
abstract
getClusterMetadata
(
key
:
string
):
Promise
<
string
>
;
public
abstract
getClusterMetadata
(
key
:
string
):
Promise
<
string
>
;
public
abstract
getTrialOutputLocalPath
(
trialJobId
:
string
):
Promise
<
string
>
;
public
abstract
getTrialOutputLocalPath
(
trialJobId
:
string
):
Promise
<
string
>
;
...
@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
...
@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
export
{
export
{
TrainingService
,
TrainingServiceError
,
TrialJobStatus
,
TrialJobApplicationForm
,
TrainingService
,
TrainingServiceError
,
TrialJobStatus
,
TrialJobApplicationForm
,
TrainingServiceMetadata
,
TrialJobDetail
,
TrialJobMetric
,
HyperParameters
,
TrainingServiceMetadata
,
TrialJobDetail
,
TrialJobMetric
,
HyperParameters
,
NNIManagerIpConfig
,
LogType
NNIManagerIpConfig
};
};
ts/nni_manager/common/utils.ts
View file @
403195f0
...
@@ -223,7 +223,7 @@ let cachedIpv4Address: string | null = null;
...
@@ -223,7 +223,7 @@ let cachedIpv4Address: string | null = null;
/**
/**
* Get IPv4 address of current machine.
* Get IPv4 address of current machine.
*/
*/
function
getIPV4Address
():
string
{
async
function
getIPV4Address
():
Promise
<
string
>
{
if
(
cachedIpv4Address
!==
null
)
{
if
(
cachedIpv4Address
!==
null
)
{
return
cachedIpv4Address
;
return
cachedIpv4Address
;
}
}
...
@@ -232,12 +232,20 @@ function getIPV4Address(): string {
...
@@ -232,12 +232,20 @@ function getIPV4Address(): string {
// since udp is connectionless, this does not send actual packets.
// since udp is connectionless, this does not send actual packets.
const
socket
=
dgram
.
createSocket
(
'
udp4
'
);
const
socket
=
dgram
.
createSocket
(
'
udp4
'
);
socket
.
connect
(
1
,
'
192.0.2.0
'
);
socket
.
connect
(
1
,
'
192.0.2.0
'
);
cachedIpv4Address
=
socket
.
address
().
address
;
for
(
let
i
=
0
;
i
<
10
;
i
++
)
{
// wait the system to initialize "connection"
await
yield_
();
try
{
cachedIpv4Address
=
socket
.
address
().
address
;
}
catch
(
error
)
{
/* retry */
}
}
cachedIpv4Address
=
socket
.
address
().
address
;
// if it still fails, throw the error
socket
.
close
();
socket
.
close
();
return
cachedIpv4Address
;
return
cachedIpv4Address
;
}
}
async
function
yield_
():
Promise
<
void
>
{
/* trigger the scheduler, do nothing */
}
/**
/**
* Get the status of canceled jobs according to the hint isEarlyStopped
* Get the status of canceled jobs according to the hint isEarlyStopped
*/
*/
...
...
ts/nni_manager/core/nnimanager.ts
View file @
403195f0
...
@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
...
@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
ExperimentManager
}
from
'
../common/experimentManager
'
;
import
{
TensorboardManager
}
from
'
../common/tensorboardManager
'
;
import
{
TensorboardManager
}
from
'
../common/tensorboardManager
'
;
import
{
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
,
LogType
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../common/trainingService
'
;
}
from
'
../common/trainingService
'
;
import
{
delay
,
getCheckpointDir
,
getExperimentRootDir
,
getLogDir
,
getMsgDispatcherCommand
,
mkDirP
,
getTunerProc
,
getLogLevel
,
isAlive
,
killPid
}
from
'
../common/utils
'
;
import
{
delay
,
getCheckpointDir
,
getExperimentRootDir
,
getLogDir
,
getMsgDispatcherCommand
,
mkDirP
,
getTunerProc
,
getLogLevel
,
isAlive
,
killPid
}
from
'
../common/utils
'
;
import
{
import
{
...
@@ -189,7 +189,6 @@ class NNIManager implements Manager {
...
@@ -189,7 +189,6 @@ class NNIManager implements Manager {
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
this
.
log
.
debug
(
`dispatcher command:
${
dispatcherCommand
}
`
);
const
checkpointDir
:
string
=
await
this
.
createCheckpointDir
();
const
checkpointDir
:
string
=
await
this
.
createCheckpointDir
();
this
.
setupTuner
(
dispatcherCommand
,
undefined
,
'
start
'
,
checkpointDir
);
this
.
setupTuner
(
dispatcherCommand
,
undefined
,
'
start
'
,
checkpointDir
);
this
.
setStatus
(
'
RUNNING
'
);
this
.
setStatus
(
'
RUNNING
'
);
await
this
.
storeExperimentProfile
();
await
this
.
storeExperimentProfile
();
this
.
run
().
catch
((
err
:
Error
)
=>
{
this
.
run
().
catch
((
err
:
Error
)
=>
{
...
@@ -403,8 +402,8 @@ class NNIManager implements Manager {
...
@@ -403,8 +402,8 @@ class NNIManager implements Manager {
// FIXME: unit test
// FIXME: unit test
}
}
public
async
getTrial
Log
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
Buffer
|
string
>
{
return
this
.
trainingService
.
getTrial
Log
(
trialJobId
,
logTyp
e
);
return
this
.
trainingService
.
getTrial
File
(
trialJobId
,
fileNam
e
);
}
}
public
getExperimentProfile
():
Promise
<
ExperimentProfile
>
{
public
getExperimentProfile
():
Promise
<
ExperimentProfile
>
{
...
@@ -433,6 +432,11 @@ class NNIManager implements Manager {
...
@@ -433,6 +432,11 @@ class NNIManager implements Manager {
return
(
value
===
undefined
?
Infinity
:
value
);
return
(
value
===
undefined
?
Infinity
:
value
);
}
}
private
get
maxTrialDuration
():
number
{
const
value
=
this
.
experimentProfile
.
params
.
maxTrialDuration
;
return
(
value
===
undefined
?
Infinity
:
toSeconds
(
value
));
}
private
async
initTrainingService
(
config
:
ExperimentConfig
):
Promise
<
TrainingService
>
{
private
async
initTrainingService
(
config
:
ExperimentConfig
):
Promise
<
TrainingService
>
{
let
platform
:
string
;
let
platform
:
string
;
if
(
Array
.
isArray
(
config
.
trainingService
))
{
if
(
Array
.
isArray
(
config
.
trainingService
))
{
...
@@ -539,6 +543,17 @@ class NNIManager implements Manager {
...
@@ -539,6 +543,17 @@ class NNIManager implements Manager {
}
}
}
}
private
async
stopTrialJobIfOverMaxDurationTimer
(
trialJobId
:
string
):
Promise
<
void
>
{
const
trialJobDetail
:
TrialJobDetail
|
undefined
=
this
.
trialJobs
.
get
(
trialJobId
);
if
(
undefined
!==
trialJobDetail
&&
trialJobDetail
.
status
===
'
RUNNING
'
&&
trialJobDetail
.
startTime
!==
undefined
){
const
isEarlyStopped
=
true
;
await
this
.
trainingService
.
cancelTrialJob
(
trialJobId
,
isEarlyStopped
);
this
.
log
.
info
(
`Trial job
${
trialJobId
}
has stoped because it is over maxTrialDuration.`
);
}
}
private
async
requestTrialJobsStatus
():
Promise
<
number
>
{
private
async
requestTrialJobsStatus
():
Promise
<
number
>
{
let
finishedTrialJobNum
:
number
=
0
;
let
finishedTrialJobNum
:
number
=
0
;
if
(
this
.
dispatcher
===
undefined
)
{
if
(
this
.
dispatcher
===
undefined
)
{
...
@@ -662,6 +677,7 @@ class NNIManager implements Manager {
...
@@ -662,6 +677,7 @@ class NNIManager implements Manager {
this
.
currSubmittedTrialNum
++
;
this
.
currSubmittedTrialNum
++
;
this
.
log
.
info
(
'
submitTrialJob: form:
'
,
form
);
this
.
log
.
info
(
'
submitTrialJob: form:
'
,
form
);
const
trialJobDetail
:
TrialJobDetail
=
await
this
.
trainingService
.
submitTrialJob
(
form
);
const
trialJobDetail
:
TrialJobDetail
=
await
this
.
trainingService
.
submitTrialJob
(
form
);
setTimeout
(
async
()
=>
this
.
stopTrialJobIfOverMaxDurationTimer
(
trialJobDetail
.
id
),
1000
*
this
.
maxTrialDuration
);
const
Snapshot
:
TrialJobDetail
=
Object
.
assign
({},
trialJobDetail
);
const
Snapshot
:
TrialJobDetail
=
Object
.
assign
({},
trialJobDetail
);
await
this
.
storeExperimentProfile
();
await
this
.
storeExperimentProfile
();
this
.
trialJobs
.
set
(
trialJobDetail
.
id
,
Snapshot
);
this
.
trialJobs
.
set
(
trialJobDetail
.
id
,
Snapshot
);
...
...
ts/nni_manager/core/test/mockedTrainingService.ts
View file @
403195f0
...
@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
...
@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
import
{
Provider
}
from
'
typescript-ioc
'
;
import
{
Provider
}
from
'
typescript-ioc
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
LogType
}
from
'
../../common/trainingService
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
const
testTrainingServiceProvider
:
Provider
=
{
const
testTrainingServiceProvider
:
Provider
=
{
get
:
()
=>
{
return
new
MockedTrainingService
();
}
get
:
()
=>
{
return
new
MockedTrainingService
();
}
...
@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
...
@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
return
deferred
.
promise
;
return
deferred
.
promise
;
}
}
public
getTrial
Log
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
{
public
getTrial
File
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
string
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
...
ts/nni_manager/package.json
View file @
403195f0
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
"child-process-promise"
:
"^2.2.1"
,
"child-process-promise"
:
"^2.2.1"
,
"express"
:
"^4.17.1"
,
"express"
:
"^4.17.1"
,
"express-joi-validator"
:
"^2.0.1"
,
"express-joi-validator"
:
"^2.0.1"
,
"http-proxy"
:
"^1.18.1"
,
"ignore"
:
"^5.1.8"
,
"ignore"
:
"^5.1.8"
,
"js-base64"
:
"^3.6.1"
,
"js-base64"
:
"^3.6.1"
,
"kubernetes-client"
:
"^6.12.1"
,
"kubernetes-client"
:
"^6.12.1"
,
...
@@ -37,6 +38,7 @@
...
@@ -37,6 +38,7 @@
"@types/chai-as-promised"
:
"^7.1.0"
,
"@types/chai-as-promised"
:
"^7.1.0"
,
"@types/express"
:
"^4.17.2"
,
"@types/express"
:
"^4.17.2"
,
"@types/glob"
:
"^7.1.3"
,
"@types/glob"
:
"^7.1.3"
,
"@types/http-proxy"
:
"^1.17.7"
,
"@types/js-base64"
:
"^3.3.1"
,
"@types/js-base64"
:
"^3.3.1"
,
"@types/js-yaml"
:
"^4.0.1"
,
"@types/js-yaml"
:
"^4.0.1"
,
"@types/lockfile"
:
"^1.0.0"
,
"@types/lockfile"
:
"^1.0.0"
,
...
...
ts/nni_manager/rest_server/nniRestServer.ts
View file @
403195f0
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
*
as
bodyParser
from
'
body-parser
'
;
import
*
as
bodyParser
from
'
body-parser
'
;
import
*
as
express
from
'
express
'
;
import
*
as
express
from
'
express
'
;
import
*
as
httpProxy
from
'
http-proxy
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
*
as
component
from
'
../common/component
'
;
import
*
as
component
from
'
../common/component
'
;
import
{
RestServer
}
from
'
../common/restServer
'
import
{
RestServer
}
from
'
../common/restServer
'
...
@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
...
@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
@
component
.
Singleton
@
component
.
Singleton
export
class
NNIRestServer
extends
RestServer
{
export
class
NNIRestServer
extends
RestServer
{
private
readonly
LOGS_ROOT_URL
:
string
=
'
/logs
'
;
private
readonly
LOGS_ROOT_URL
:
string
=
'
/logs
'
;
protected
netronProxy
:
any
=
null
;
protected
API_ROOT_URL
:
string
=
'
/api/v1/nni
'
;
protected
API_ROOT_URL
:
string
=
'
/api/v1/nni
'
;
/**
/**
...
@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
...
@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
constructor
()
{
constructor
()
{
super
();
super
();
this
.
API_ROOT_URL
=
getAPIRootUrl
();
this
.
API_ROOT_URL
=
getAPIRootUrl
();
this
.
netronProxy
=
httpProxy
.
createProxyServer
();
}
}
/**
/**
...
@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
...
@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
this
.
app
.
use
(
bodyParser
.
json
({
limit
:
'
50mb
'
}));
this
.
app
.
use
(
bodyParser
.
json
({
limit
:
'
50mb
'
}));
this
.
app
.
use
(
this
.
API_ROOT_URL
,
createRestHandler
(
this
));
this
.
app
.
use
(
this
.
API_ROOT_URL
,
createRestHandler
(
this
));
this
.
app
.
use
(
this
.
LOGS_ROOT_URL
,
express
.
static
(
getLogDir
()));
this
.
app
.
use
(
this
.
LOGS_ROOT_URL
,
express
.
static
(
getLogDir
()));
this
.
app
.
all
(
'
/netron/*
'
,
(
req
:
express
.
Request
,
res
:
express
.
Response
)
=>
{
delete
req
.
headers
.
host
;
req
.
url
=
req
.
url
.
replace
(
'
/netron
'
,
'
/
'
);
this
.
netronProxy
.
web
(
req
,
res
,
{
changeOrigin
:
true
,
target
:
'
https://netron.app
'
});
});
this
.
app
.
get
(
'
*
'
,
(
req
:
express
.
Request
,
res
:
express
.
Response
)
=>
{
this
.
app
.
get
(
'
*
'
,
(
req
:
express
.
Request
,
res
:
express
.
Response
)
=>
{
res
.
sendFile
(
path
.
resolve
(
'
static/index.html
'
));
res
.
sendFile
(
path
.
resolve
(
'
static/index.html
'
));
});
});
...
...
ts/nni_manager/rest_server/restHandler.ts
View file @
403195f0
...
@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
...
@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
import
{
getVersion
}
from
'
../common/utils
'
;
import
{
getVersion
}
from
'
../common/utils
'
;
import
{
MetricType
}
from
'
../common/datastore
'
;
import
{
MetricType
}
from
'
../common/datastore
'
;
import
{
ProfileUpdateType
}
from
'
../common/manager
'
;
import
{
ProfileUpdateType
}
from
'
../common/manager
'
;
import
{
LogType
,
TrialJobStatus
}
from
'
../common/trainingService
'
;
import
{
TrialJobStatus
}
from
'
../common/trainingService
'
;
const
expressJoi
=
require
(
'
express-joi-validator
'
);
const
expressJoi
=
require
(
'
express-joi-validator
'
);
...
@@ -53,6 +53,7 @@ class NNIRestHandler {
...
@@ -53,6 +53,7 @@ class NNIRestHandler {
this
.
version
(
router
);
this
.
version
(
router
);
this
.
checkStatus
(
router
);
this
.
checkStatus
(
router
);
this
.
getExperimentProfile
(
router
);
this
.
getExperimentProfile
(
router
);
this
.
getExperimentMetadata
(
router
);
this
.
updateExperimentProfile
(
router
);
this
.
updateExperimentProfile
(
router
);
this
.
importData
(
router
);
this
.
importData
(
router
);
this
.
getImportedData
(
router
);
this
.
getImportedData
(
router
);
...
@@ -66,7 +67,7 @@ class NNIRestHandler {
...
@@ -66,7 +67,7 @@ class NNIRestHandler {
this
.
getMetricData
(
router
);
this
.
getMetricData
(
router
);
this
.
getMetricDataByRange
(
router
);
this
.
getMetricDataByRange
(
router
);
this
.
getLatestMetricData
(
router
);
this
.
getLatestMetricData
(
router
);
this
.
getTrial
Log
(
router
);
this
.
getTrial
File
(
router
);
this
.
exportData
(
router
);
this
.
exportData
(
router
);
this
.
getExperimentsInfo
(
router
);
this
.
getExperimentsInfo
(
router
);
this
.
startTensorboardTask
(
router
);
this
.
startTensorboardTask
(
router
);
...
@@ -296,13 +297,20 @@ class NNIRestHandler {
...
@@ -296,13 +297,20 @@ class NNIRestHandler {
});
});
}
}
private
getTrialLog
(
router
:
Router
):
void
{
private
getTrialFile
(
router
:
Router
):
void
{
router
.
get
(
'
/trial-log/:id/:type
'
,
async
(
req
:
Request
,
res
:
Response
)
=>
{
router
.
get
(
'
/trial-file/:id/:filename
'
,
async
(
req
:
Request
,
res
:
Response
)
=>
{
this
.
nniManager
.
getTrialLog
(
req
.
params
.
id
,
req
.
params
.
type
as
LogType
).
then
((
log
:
string
)
=>
{
let
encoding
:
string
|
null
=
null
;
if
(
log
===
''
)
{
const
filename
=
req
.
params
.
filename
;
log
=
'
No logs available.
'
if
(
!
filename
.
includes
(
'
.
'
)
||
filename
.
match
(
/.*
\.(
txt|log
)
/g
))
{
encoding
=
'
utf8
'
;
}
this
.
nniManager
.
getTrialFile
(
req
.
params
.
id
,
filename
).
then
((
content
:
Buffer
|
string
)
=>
{
if
(
content
instanceof
Buffer
)
{
res
.
header
(
'
Content-Type
'
,
'
application/octet-stream
'
);
}
else
if
(
content
===
''
)
{
content
=
`
${
filename
}
is empty.`
;
}
}
res
.
send
(
log
);
res
.
send
(
content
);
}).
catch
((
err
:
Error
)
=>
{
}).
catch
((
err
:
Error
)
=>
{
this
.
handleError
(
err
,
res
);
this
.
handleError
(
err
,
res
);
});
});
...
@@ -319,6 +327,24 @@ class NNIRestHandler {
...
@@ -319,6 +327,24 @@ class NNIRestHandler {
});
});
}
}
private
getExperimentMetadata
(
router
:
Router
):
void
{
router
.
get
(
'
/experiment-metadata
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
Promise
.
all
([
this
.
nniManager
.
getExperimentProfile
(),
this
.
experimentsManager
.
getExperimentsInfo
()
]).
then
(([
profile
,
experimentInfo
])
=>
{
for
(
const
info
of
experimentInfo
as
any
)
{
if
(
info
.
id
===
profile
.
id
)
{
res
.
send
(
info
);
break
;
}
}
}).
catch
((
err
:
Error
)
=>
{
this
.
handleError
(
err
,
res
);
});
});
}
private
getExperimentsInfo
(
router
:
Router
):
void
{
private
getExperimentsInfo
(
router
:
Router
):
void
{
router
.
get
(
'
/experiments-info
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
router
.
get
(
'
/experiments-info
'
,
(
req
:
Request
,
res
:
Response
)
=>
{
this
.
experimentsManager
.
getExperimentsInfo
().
then
((
experimentInfo
:
JSON
)
=>
{
this
.
experimentsManager
.
getExperimentsInfo
().
then
((
experimentInfo
:
JSON
)
=>
{
...
...
ts/nni_manager/rest_server/test/mockedNNIManager.ts
View file @
403195f0
...
@@ -13,7 +13,7 @@ import {
...
@@ -13,7 +13,7 @@ import {
TrialJobStatistics
,
NNIManagerStatus
TrialJobStatistics
,
NNIManagerStatus
}
from
'
../../common/manager
'
;
}
from
'
../../common/manager
'
;
import
{
import
{
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobStatus
,
LogType
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
export
const
testManagerProvider
:
Provider
=
{
export
const
testManagerProvider
:
Provider
=
{
...
@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
...
@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
public
getLatestMetricData
():
Promise
<
MetricDataRecord
[]
>
{
public
getLatestMetricData
():
Promise
<
MetricDataRecord
[]
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
public
getTrial
Log
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
{
public
getTrial
File
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
string
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
public
getExperimentProfile
():
Promise
<
ExperimentProfile
>
{
public
getExperimentProfile
():
Promise
<
ExperimentProfile
>
{
...
...
ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
View file @
403195f0
...
@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
...
@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
import
{
NNIManagerIpConfig
,
TrialJobDetail
,
TrialJobMetric
,
LogType
NNIManagerIpConfig
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
import
{
delay
,
getExperimentRootDir
,
getIPV4Address
,
getJobCancelStatus
,
getVersion
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
delay
,
getExperimentRootDir
,
getIPV4Address
,
getJobCancelStatus
,
getVersion
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
AzureStorageClientUtility
}
from
'
./azureStorageClientUtils
'
;
import
{
AzureStorageClientUtility
}
from
'
./azureStorageClientUtils
'
;
...
@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
...
@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
return
Promise
.
resolve
(
kubernetesTrialJob
);
return
Promise
.
resolve
(
kubernetesTrialJob
);
}
}
public
async
getTrial
Log
(
_trialJobId
:
string
,
_
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
_trialJobId
:
string
,
_
filename
:
string
):
Promise
<
string
|
Buffer
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
@@ -277,7 +277,7 @@ abstract class KubernetesTrainingService {
...
@@ -277,7 +277,7 @@ abstract class KubernetesTrainingService {
if
(
gpuNum
===
0
)
{
if
(
gpuNum
===
0
)
{
nvidiaScript
=
'
export CUDA_VISIBLE_DEVICES=
'
;
nvidiaScript
=
'
export CUDA_VISIBLE_DEVICES=
'
;
}
}
const
nniManagerIp
:
string
=
this
.
nniManagerIpConfig
?
this
.
nniManagerIpConfig
.
nniManagerIp
:
getIPV4Address
();
const
nniManagerIp
:
string
=
this
.
nniManagerIpConfig
?
this
.
nniManagerIpConfig
.
nniManagerIp
:
await
getIPV4Address
();
const
version
:
string
=
this
.
versionCheck
?
await
getVersion
()
:
''
;
const
version
:
string
=
this
.
versionCheck
?
await
getVersion
()
:
''
;
const
runScript
:
string
=
String
.
Format
(
const
runScript
:
string
=
String
.
Format
(
kubernetesScriptFormat
,
kubernetesScriptFormat
,
...
...
ts/nni_manager/training_service/local/localTrainingService.ts
View file @
403195f0
...
@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
...
@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
import
{
HyperParameters
,
TrainingService
,
TrialJobApplicationForm
,
HyperParameters
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
,
LogType
TrialJobDetail
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
import
{
import
{
delay
,
generateParamFileName
,
getExperimentRootDir
,
getJobCancelStatus
,
getNewLine
,
isAlive
,
uniqueString
delay
,
generateParamFileName
,
getExperimentRootDir
,
getJobCancelStatus
,
getNewLine
,
isAlive
,
uniqueString
...
@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
...
@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
return
trialJob
;
return
trialJob
;
}
}
public
async
getTrialLog
(
trialJobId
:
string
,
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrialFile
(
trialJobId
:
string
,
fileName
:
string
):
Promise
<
string
|
Buffer
>
{
let
logPath
:
string
;
// check filename here for security
if
(
logType
===
'
TRIAL_LOG
'
)
{
if
(
!
[
'
trial.log
'
,
'
stderr
'
,
'
model.onnx
'
,
'
stdout
'
].
includes
(
fileName
))
{
logPath
=
path
.
join
(
this
.
rootDir
,
'
trials
'
,
trialJobId
,
'
trial.log
'
);
throw
new
Error
(
`File unaccessible:
${
fileName
}
`
);
}
else
if
(
logType
===
'
TRIAL_STDOUT
'
){
}
logPath
=
path
.
join
(
this
.
rootDir
,
'
trials
'
,
trialJobId
,
'
stdout
'
);
let
encoding
:
string
|
null
=
null
;
}
else
if
(
logType
===
'
TRIAL_ERROR
'
)
{
if
(
!
fileName
.
includes
(
'
.
'
)
||
fileName
.
match
(
/.*
\.(
txt|log
)
/g
))
{
logPath
=
path
.
join
(
this
.
rootDir
,
'
trials
'
,
trialJobId
,
'
stderr
'
);
encoding
=
'
utf8
'
;
}
else
{
}
throw
new
Error
(
'
unexpected log type
'
);
const
logPath
=
path
.
join
(
this
.
rootDir
,
'
trials
'
,
trialJobId
,
fileName
);
if
(
!
fs
.
existsSync
(
logPath
))
{
throw
new
Error
(
`File not found:
${
logPath
}
`
);
}
}
return
fs
.
promises
.
readFile
(
logPath
,
'
utf8
'
);
return
fs
.
promises
.
readFile
(
logPath
,
{
encoding
:
encoding
as
any
}
);
}
}
public
addTrialJobMetricListener
(
listener
:
(
metric
:
TrialJobMetric
)
=>
void
):
void
{
public
addTrialJobMetricListener
(
listener
:
(
metric
:
TrialJobMetric
)
=>
void
):
void
{
...
...
ts/nni_manager/training_service/pai/paiTrainingService.ts
View file @
403195f0
...
@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
...
@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
import
{
HyperParameters
,
NNIManagerIpConfig
,
TrainingService
,
HyperParameters
,
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
LogType
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
import
{
delay
}
from
'
../../common/utils
'
;
import
{
delay
}
from
'
../../common/utils
'
;
import
{
ExperimentConfig
,
OpenpaiConfig
,
flattenConfig
,
toMegaBytes
}
from
'
../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
OpenpaiConfig
,
flattenConfig
,
toMegaBytes
}
from
'
../../common/experimentConfig
'
;
...
@@ -23,10 +23,7 @@ import { PAIJobInfoCollector } from './paiJobInfoCollector';
...
@@ -23,10 +23,7 @@ import { PAIJobInfoCollector } from './paiJobInfoCollector';
import
{
PAIJobRestServer
}
from
'
./paiJobRestServer
'
;
import
{
PAIJobRestServer
}
from
'
./paiJobRestServer
'
;
import
{
PAITrialJobDetail
,
PAI_TRIAL_COMMAND_FORMAT
}
from
'
./paiConfig
'
;
import
{
PAITrialJobDetail
,
PAI_TRIAL_COMMAND_FORMAT
}
from
'
./paiConfig
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
{
import
{
generateParamFileName
,
getIPV4Address
,
uniqueString
}
from
'
../../common/utils
'
;
generateParamFileName
,
getIPV4Address
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
CONTAINER_INSTALL_NNI_SHELL_FORMAT
}
from
'
../common/containerJobData
'
;
import
{
CONTAINER_INSTALL_NNI_SHELL_FORMAT
}
from
'
../common/containerJobData
'
;
import
{
execMkdir
,
validateCodeDir
,
execCopydir
}
from
'
../common/util
'
;
import
{
execMkdir
,
validateCodeDir
,
execCopydir
}
from
'
../common/util
'
;
...
@@ -127,7 +124,7 @@ class PAITrainingService implements TrainingService {
...
@@ -127,7 +124,7 @@ class PAITrainingService implements TrainingService {
return
jobs
;
return
jobs
;
}
}
public
async
getTrial
Log
(
_trialJobId
:
string
,
_
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
_trialJobId
:
string
,
_
fileName
:
string
):
Promise
<
string
|
Buffer
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
@@ -332,7 +329,7 @@ class PAITrainingService implements TrainingService {
...
@@ -332,7 +329,7 @@ class PAITrainingService implements TrainingService {
return
trialJobDetail
;
return
trialJobDetail
;
}
}
private
generateNNITrialCommand
(
trialJobDetail
:
PAITrialJobDetail
,
command
:
string
):
string
{
private
async
generateNNITrialCommand
(
trialJobDetail
:
PAITrialJobDetail
,
command
:
string
):
Promise
<
string
>
{
const
containerNFSExpCodeDir
=
`
${
this
.
config
.
containerStorageMountPoint
}
/
${
this
.
experimentId
}
/nni-code`
;
const
containerNFSExpCodeDir
=
`
${
this
.
config
.
containerStorageMountPoint
}
/
${
this
.
experimentId
}
/nni-code`
;
const
containerWorkingDir
:
string
=
`
${
this
.
config
.
containerStorageMountPoint
}
/
${
this
.
experimentId
}
/
${
trialJobDetail
.
id
}
`
;
const
containerWorkingDir
:
string
=
`
${
this
.
config
.
containerStorageMountPoint
}
/
${
this
.
experimentId
}
/
${
trialJobDetail
.
id
}
`
;
const
nniPaiTrialCommand
:
string
=
String
.
Format
(
const
nniPaiTrialCommand
:
string
=
String
.
Format
(
...
@@ -345,7 +342,7 @@ class PAITrainingService implements TrainingService {
...
@@ -345,7 +342,7 @@ class PAITrainingService implements TrainingService {
false
,
// multi-phase
false
,
// multi-phase
containerNFSExpCodeDir
,
containerNFSExpCodeDir
,
command
,
command
,
this
.
config
.
nniManagerIp
||
getIPV4Address
(),
this
.
config
.
nniManagerIp
||
await
getIPV4Address
(),
this
.
paiRestServerPort
,
this
.
paiRestServerPort
,
this
.
nniVersion
,
this
.
nniVersion
,
this
.
logCollection
this
.
logCollection
...
@@ -356,7 +353,7 @@ class PAITrainingService implements TrainingService {
...
@@ -356,7 +353,7 @@ class PAITrainingService implements TrainingService {
}
}
private
generateJobConfigInYamlFormat
(
trialJobDetail
:
PAITrialJobDetail
):
any
{
private
async
generateJobConfigInYamlFormat
(
trialJobDetail
:
PAITrialJobDetail
):
Promise
<
any
>
{
const
jobName
=
`nni_exp_
${
this
.
experimentId
}
_trial_
${
trialJobDetail
.
id
}
`
const
jobName
=
`nni_exp_
${
this
.
experimentId
}
_trial_
${
trialJobDetail
.
id
}
`
let
nniJobConfig
:
any
=
undefined
;
let
nniJobConfig
:
any
=
undefined
;
...
@@ -367,7 +364,7 @@ class PAITrainingService implements TrainingService {
...
@@ -367,7 +364,7 @@ class PAITrainingService implements TrainingService {
// Each command will be formatted to NNI style
// Each command will be formatted to NNI style
for
(
const
taskRoleIndex
in
nniJobConfig
.
taskRoles
)
{
for
(
const
taskRoleIndex
in
nniJobConfig
.
taskRoles
)
{
const
commands
=
nniJobConfig
.
taskRoles
[
taskRoleIndex
].
commands
const
commands
=
nniJobConfig
.
taskRoles
[
taskRoleIndex
].
commands
const
nniTrialCommand
=
this
.
generateNNITrialCommand
(
trialJobDetail
,
commands
.
join
(
"
&&
"
).
replace
(
/
([
"'$`
\\])
/g
,
'
\\
$1
'
));
const
nniTrialCommand
=
await
this
.
generateNNITrialCommand
(
trialJobDetail
,
commands
.
join
(
"
&&
"
).
replace
(
/
([
"'$`
\\])
/g
,
'
\\
$1
'
));
nniJobConfig
.
taskRoles
[
taskRoleIndex
].
commands
=
[
nniTrialCommand
]
nniJobConfig
.
taskRoles
[
taskRoleIndex
].
commands
=
[
nniTrialCommand
]
}
}
...
@@ -399,7 +396,7 @@ class PAITrainingService implements TrainingService {
...
@@ -399,7 +396,7 @@ class PAITrainingService implements TrainingService {
memoryMB
:
toMegaBytes
(
this
.
config
.
trialMemorySize
)
memoryMB
:
toMegaBytes
(
this
.
config
.
trialMemorySize
)
},
},
commands
:
[
commands
:
[
this
.
generateNNITrialCommand
(
trialJobDetail
,
this
.
config
.
trialCommand
)
await
this
.
generateNNITrialCommand
(
trialJobDetail
,
this
.
config
.
trialCommand
)
]
]
}
}
},
},
...
@@ -456,7 +453,7 @@ class PAITrainingService implements TrainingService {
...
@@ -456,7 +453,7 @@ class PAITrainingService implements TrainingService {
}
}
//Generate Job Configuration in yaml format
//Generate Job Configuration in yaml format
const
paiJobConfig
=
this
.
generateJobConfigInYamlFormat
(
trialJobDetail
);
const
paiJobConfig
=
await
this
.
generateJobConfigInYamlFormat
(
trialJobDetail
);
this
.
log
.
debug
(
paiJobConfig
);
this
.
log
.
debug
(
paiJobConfig
);
// Step 2. Submit PAI job via Rest call
// Step 2. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
...
...
ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
View file @
403195f0
...
@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
...
@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
import
{
ObservableTimer
}
from
'
../../common/observableTimer
'
;
import
{
ObservableTimer
}
from
'
../../common/observableTimer
'
;
import
{
import
{
HyperParameters
,
TrainingService
,
TrialJobApplicationForm
,
HyperParameters
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
LogType
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
import
{
import
{
delay
,
generateParamFileName
,
getExperimentRootDir
,
getIPV4Address
,
getJobCancelStatus
,
delay
,
generateParamFileName
,
getExperimentRootDir
,
getIPV4Address
,
getJobCancelStatus
,
...
@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
* @param _trialJobId ID of trial job
* @param _trialJobId ID of trial job
* @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
* @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
*/
*/
public
async
getTrial
Log
(
_trialJobId
:
string
,
_
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
_trialJobId
:
string
,
_
fileName
:
string
):
Promise
<
string
|
Buffer
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
@@ -491,7 +491,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -491,7 +491,7 @@ class RemoteMachineTrainingService implements TrainingService {
cudaVisible
=
`CUDA_VISIBLE_DEVICES=" "`
;
cudaVisible
=
`CUDA_VISIBLE_DEVICES=" "`
;
}
}
}
}
const
nniManagerIp
:
string
=
this
.
config
.
nniManagerIp
?
this
.
config
.
nniManagerIp
:
getIPV4Address
();
const
nniManagerIp
:
string
=
this
.
config
.
nniManagerIp
?
this
.
config
.
nniManagerIp
:
await
getIPV4Address
();
if
(
this
.
remoteRestServerPort
===
undefined
)
{
if
(
this
.
remoteRestServerPort
===
undefined
)
{
const
restServer
:
RemoteMachineJobRestServer
=
component
.
get
(
RemoteMachineJobRestServer
);
const
restServer
:
RemoteMachineJobRestServer
=
component
.
get
(
RemoteMachineJobRestServer
);
this
.
remoteRestServerPort
=
restServer
.
clusterRestServerPort
;
this
.
remoteRestServerPort
=
restServer
.
clusterRestServerPort
;
...
...
ts/nni_manager/training_service/reusable/routerTrainingService.ts
View file @
403195f0
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
OpenpaiConfig
}
from
'
../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
RemoteConfig
,
OpenpaiConfig
}
from
'
../../common/experimentConfig
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
,
LogType
}
from
'
../../common/trainingService
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
import
{
delay
}
from
'
../../common/utils
'
;
import
{
delay
}
from
'
../../common/utils
'
;
import
{
PAITrainingService
}
from
'
../pai/paiTrainingService
'
;
import
{
PAITrainingService
}
from
'
../pai/paiTrainingService
'
;
import
{
RemoteMachineTrainingService
}
from
'
../remote_machine/remoteMachineTrainingService
'
;
import
{
RemoteMachineTrainingService
}
from
'
../remote_machine/remoteMachineTrainingService
'
;
...
@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
...
@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
return
await
this
.
internalTrainingService
.
getTrialJob
(
trialJobId
);
return
await
this
.
internalTrainingService
.
getTrialJob
(
trialJobId
);
}
}
public
async
getTrial
Log
(
_trialJobId
:
string
,
_
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
_trialJobId
:
string
,
_
fileName
:
string
):
Promise
<
string
|
Buffer
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
...
ts/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
403195f0
...
@@ -13,7 +13,7 @@ import * as component from '../../common/component';
...
@@ -13,7 +13,7 @@ import * as component from '../../common/component';
import
{
NNIError
,
NNIErrorNames
,
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
NNIError
,
NNIErrorNames
,
MethodNotImplementedError
}
from
'
../../common/errors
'
;
import
{
getBasePort
,
getExperimentId
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getBasePort
,
getExperimentId
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobMetric
,
TrialJobStatus
,
LogType
}
from
'
../../common/trainingService
'
;
import
{
TrainingService
,
TrialJobApplicationForm
,
TrialJobMetric
,
TrialJobStatus
}
from
'
../../common/trainingService
'
;
import
{
delay
,
getExperimentRootDir
,
getIPV4Address
,
getLogLevel
,
getVersion
,
mkDirPSync
,
randomSelect
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
delay
,
getExperimentRootDir
,
getIPV4Address
,
getLogLevel
,
getVersion
,
mkDirPSync
,
randomSelect
,
uniqueString
}
from
'
../../common/utils
'
;
import
{
ExperimentConfig
,
SharedStorageConfig
}
from
'
../../common/experimentConfig
'
;
import
{
ExperimentConfig
,
SharedStorageConfig
}
from
'
../../common/experimentConfig
'
;
import
{
GPU_INFO
,
INITIALIZED
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
REPORT_METRIC_DATA
,
SEND_TRIAL_JOB_PARAMETER
,
STDOUT
,
TRIAL_END
,
VERSION_CHECK
}
from
'
../../core/commands
'
;
import
{
GPU_INFO
,
INITIALIZED
,
KILL_TRIAL_JOB
,
NEW_TRIAL_JOB
,
REPORT_METRIC_DATA
,
SEND_TRIAL_JOB_PARAMETER
,
STDOUT
,
TRIAL_END
,
VERSION_CHECK
}
from
'
../../core/commands
'
;
...
@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
return
trial
;
return
trial
;
}
}
public
async
getTrial
Log
(
_trialJobId
:
string
,
_
logType
:
LogType
):
Promise
<
string
>
{
public
async
getTrial
File
(
_trialJobId
:
string
,
_
fileName
:
string
):
Promise
<
string
|
Buffer
>
{
throw
new
MethodNotImplementedError
();
throw
new
MethodNotImplementedError
();
}
}
...
@@ -216,7 +216,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -216,7 +216,7 @@ class TrialDispatcher implements TrainingService {
for
(
const
environmentService
of
this
.
environmentServiceList
)
{
for
(
const
environmentService
of
this
.
environmentServiceList
)
{
const
runnerSettings
:
RunnerSettings
=
new
RunnerSettings
();
const
runnerSettings
:
RunnerSettings
=
new
RunnerSettings
();
runnerSettings
.
nniManagerIP
=
this
.
config
.
nniManagerIp
===
undefined
?
getIPV4Address
()
:
this
.
config
.
nniManagerIp
;
runnerSettings
.
nniManagerIP
=
this
.
config
.
nniManagerIp
===
undefined
?
await
getIPV4Address
()
:
this
.
config
.
nniManagerIp
;
runnerSettings
.
nniManagerPort
=
getBasePort
()
+
1
;
runnerSettings
.
nniManagerPort
=
getBasePort
()
+
1
;
runnerSettings
.
commandChannel
=
environmentService
.
getCommandChannel
.
channelName
;
runnerSettings
.
commandChannel
=
environmentService
.
getCommandChannel
.
channelName
;
runnerSettings
.
enableGpuCollector
=
this
.
enableGpuScheduler
;
runnerSettings
.
enableGpuCollector
=
this
.
enableGpuScheduler
;
...
...
Prev
1
…
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment