Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cd3a912a
Unverified
Commit
cd3a912a
authored
Nov 27, 2019
by
SparkSnail
Committed by
GitHub
Nov 27, 2019
Browse files
Merge pull request #218 from microsoft/master
merge master
parents
a0846f2a
e9cba778
Changes
375
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
353 additions
and
1155 deletions
+353
-1155
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+159
-22
src/sdk/pynni/tests/test_msg_dispatcher.py
src/sdk/pynni/tests/test_msg_dispatcher.py
+2
-20
src/sdk/pynni/tests/test_protocol.py
src/sdk/pynni/tests/test_protocol.py
+2
-20
src/sdk/pynni/tests/test_smartparam.py
src/sdk/pynni/tests/test_smartparam.py
+2
-19
src/sdk/pynni/tests/test_trial.py
src/sdk/pynni/tests/test_trial.py
+3
-20
src/sdk/pynni/tests/test_utils.py
src/sdk/pynni/tests/test_utils.py
+3
-20
src/webui/package.json
src/webui/package.json
+1
-1
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
+69
-40
src/webui/src/components/trial-detail/Duration.tsx
src/webui/src/components/trial-detail/Duration.tsx
+32
-85
src/webui/src/components/trial-detail/Intermediate.tsx
src/webui/src/components/trial-detail/Intermediate.tsx
+29
-3
src/webui/src/components/trial-detail/Para.tsx
src/webui/src/components/trial-detail/Para.tsx
+1
-0
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+5
-6
src/webui/src/static/function.ts
src/webui/src/static/function.ts
+1
-1
src/webui/src/static/interface.ts
src/webui/src/static/interface.ts
+8
-2
src/webui/src/static/model/trial.ts
src/webui/src/static/model/trial.ts
+18
-1
src/webui/yarn.lock
src/webui/yarn.lock
+7
-856
test/async_sharing_test/main.py
test/async_sharing_test/main.py
+3
-0
test/async_sharing_test/simple_tuner.py
test/async_sharing_test/simple_tuner.py
+3
-0
test/cli_test.py
test/cli_test.py
+3
-20
test/config_test.py
test/config_test.py
+2
-19
No files found.
src/sdk/pynni/tests/test_compressor.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
nni.compression.torch
as
torch_compressor
import
nni.compression.torch
as
torch_compressor
import
math
if
tf
.
__version__
>=
'2.0'
:
if
tf
.
__version__
>=
'2.0'
:
import
nni.compression.tensorflow
as
tf_compressor
import
nni.compression.tensorflow
as
tf_compressor
def
get_tf_model
():
def
get_tf_model
():
model
=
tf
.
keras
.
models
.
Sequential
([
model
=
tf
.
keras
.
models
.
Sequential
([
tf
.
keras
.
layers
.
Conv2D
(
filters
=
5
,
kernel_size
=
7
,
input_shape
=
[
28
,
28
,
1
],
activation
=
'relu'
,
padding
=
"SAME"
),
tf
.
keras
.
layers
.
Conv2D
(
filters
=
5
,
kernel_size
=
7
,
input_shape
=
[
28
,
28
,
1
],
activation
=
'relu'
,
padding
=
"SAME"
),
...
@@ -20,43 +25,70 @@ def get_tf_model():
...
@@ -20,43 +25,70 @@ def get_tf_model():
tf
.
keras
.
layers
.
Dense
(
units
=
10
,
activation
=
'softmax'
),
tf
.
keras
.
layers
.
Dense
(
units
=
10
,
activation
=
'softmax'
),
])
])
model
.
compile
(
loss
=
"sparse_categorical_crossentropy"
,
model
.
compile
(
loss
=
"sparse_categorical_crossentropy"
,
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
1e-3
),
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
1e-3
),
metrics
=
[
"accuracy"
])
metrics
=
[
"accuracy"
])
return
model
return
model
class
TorchModel
(
torch
.
nn
.
Module
):
class
TorchModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
5
,
5
,
1
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
5
,
5
,
1
)
self
.
bn1
=
torch
.
nn
.
BatchNorm2d
(
5
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
5
,
10
,
5
,
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
5
,
10
,
5
,
1
)
self
.
bn2
=
torch
.
nn
.
BatchNorm2d
(
10
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
10
,
100
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
10
,
100
)
self
.
fc2
=
torch
.
nn
.
Linear
(
100
,
10
)
self
.
fc2
=
torch
.
nn
.
Linear
(
100
,
10
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
relu
(
self
.
bn1
(
self
.
conv1
(
x
))
)
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
relu
(
self
.
conv2
(
x
))
x
=
F
.
relu
(
self
.
bn2
(
self
.
conv2
(
x
))
)
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
x
.
view
(
-
1
,
4
*
4
*
10
)
x
=
x
.
view
(
-
1
,
4
*
4
*
10
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
def
tf2
(
func
):
def
tf2
(
func
):
def
test_tf2_func
(
*
args
):
def
test_tf2_func
(
*
args
):
if
tf
.
__version__
>=
'2.0'
:
if
tf
.
__version__
>=
'2.0'
:
func
(
*
args
)
func
(
*
args
)
return
test_tf2_func
return
test_tf2_func
k1
=
[[
1
]
*
3
]
*
3
# for fpgm filter pruner test
k2
=
[[
2
]
*
3
]
*
3
w
=
np
.
array
([[[[
i
+
1
]
*
3
]
*
3
]
*
5
for
i
in
range
(
10
)])
k3
=
[[
3
]
*
3
]
*
3
k4
=
[[
4
]
*
3
]
*
3
k5
=
[[
5
]
*
3
]
*
3
w
=
[[
k1
,
k2
,
k3
,
k4
,
k5
]]
*
10
class
CompressorTestCase
(
TestCase
):
class
CompressorTestCase
(
TestCase
):
def
test_torch_quantizer_modules_detection
(
self
):
# test if modules can be detected
model
=
TorchModel
()
config_list
=
[{
'quant_types'
:
[
'weight'
],
'quant_bits'
:
8
,
'op_types'
:[
'Conv2d'
,
'Linear'
]
},
{
'quant_types'
:
[
'output'
],
'quant_bits'
:
8
,
'quant_start_step'
:
0
,
'op_types'
:[
'ReLU'
]
}]
model
.
relu
=
torch
.
nn
.
ReLU
()
quantizer
=
torch_compressor
.
QAT_Quantizer
(
model
,
config_list
)
quantizer
.
compress
()
modules_to_compress
=
quantizer
.
get_modules_to_compress
()
modules_to_compress_name
=
[
t
[
0
].
name
for
t
in
modules_to_compress
]
assert
"conv1"
in
modules_to_compress_name
assert
"conv2"
in
modules_to_compress_name
assert
"fc1"
in
modules_to_compress_name
assert
"fc2"
in
modules_to_compress_name
assert
"relu"
in
modules_to_compress_name
assert
len
(
modules_to_compress_name
)
==
5
def
test_torch_level_pruner
(
self
):
def
test_torch_level_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
[
'default'
]}]
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
[
'default'
]}]
...
@@ -74,7 +106,7 @@ class CompressorTestCase(TestCase):
...
@@ -74,7 +106,7 @@ class CompressorTestCase(TestCase):
'quant_bits'
:
{
'quant_bits'
:
{
'weight'
:
8
,
'weight'
:
8
,
},
},
'op_types'
:[
'Conv2d'
,
'Linear'
]
'op_types'
:
[
'Conv2d'
,
'Linear'
]
}]
}]
torch_compressor
.
NaiveQuantizer
(
model
,
configure_list
).
compress
()
torch_compressor
.
NaiveQuantizer
(
model
,
configure_list
).
compress
()
...
@@ -84,16 +116,16 @@ class CompressorTestCase(TestCase):
...
@@ -84,16 +116,16 @@ class CompressorTestCase(TestCase):
def
test_torch_fpgm_pruner
(
self
):
def
test_torch_fpgm_pruner
(
self
):
"""
"""
With filters(kernels) defined as above (
k1 - k5
), it is obvious that
k3
is the Geometric Median
With filters(kernels)
weights
defined as above (
w
), it is obvious that
w[4] and w[5]
is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper:
which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out
all k3
, this can be verified through:
So if sparsity is 0.2, the expected masks should mask out
w[4] and w[5]
, this can be verified through:
`all(torch.sum(masks, (
0
, 2, 3)).numpy() == np.array([
90
.,
90
.,
0
.,
90., 90
.]))`
`all(torch.sum(masks, (
1
, 2, 3)).numpy() == np.array([
45
.,
45
.,
45
.,
45., 0., 0., 45., 45., 45., 45
.]))`
If sparsity is 0.6, the expected masks should mask out
all k2, k3, k4
, this can be verified through:
If sparsity is 0.6, the expected masks should mask out
w[2] - w[7]
, this can be verified through:
`all(torch.sum(masks, (
0
, 2, 3)).numpy() == np.array([
9
0., 0., 0., 0.,
9
0.]))`
`all(torch.sum(masks, (
1
, 2, 3)).numpy() == np.array([
45., 45.,
0., 0., 0., 0., 0
., 0., 45., 45
.]))`
"""
"""
model
=
TorchModel
()
model
=
TorchModel
()
...
@@ -103,12 +135,12 @@ class CompressorTestCase(TestCase):
...
@@ -103,12 +135,12 @@ class CompressorTestCase(TestCase):
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
layer
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv2'
,
model
.
conv2
)
layer
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv2'
,
model
.
conv2
)
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
])
assert
all
(
torch
.
sum
(
masks
,
(
0
,
2
,
3
)).
numpy
()
==
np
.
array
([
90
.
,
90
.
,
0
.
,
90.
,
90
.
]))
assert
all
(
torch
.
sum
(
masks
,
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
45
.
,
45
.
,
45
.
,
45.
,
0.
,
0.
,
45.
,
45.
,
45.
,
45
.
]))
pruner
.
update_epoch
(
1
)
pruner
.
update_epoch
(
1
)
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
])
assert
all
(
torch
.
sum
(
masks
,
(
0
,
2
,
3
)).
numpy
()
==
np
.
array
([
9
0.
,
0.
,
0.
,
0.
,
9
0.
]))
assert
all
(
torch
.
sum
(
masks
,
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
45.
,
45.
,
0.
,
0.
,
0.
,
0.
,
0
.
,
0.
,
45.
,
45
.
]))
@
tf2
@
tf2
def
test_tf_fpgm_pruner
(
self
):
def
test_tf_fpgm_pruner
(
self
):
...
@@ -122,17 +154,122 @@ class CompressorTestCase(TestCase):
...
@@ -122,17 +154,122 @@ class CompressorTestCase(TestCase):
layer
=
tf_compressor
.
compressor
.
LayerInfo
(
model
.
layers
[
2
])
layer
=
tf_compressor
.
compressor
.
LayerInfo
(
model
.
layers
[
2
])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
]).
numpy
()
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
]).
numpy
()
masks
=
masks
.
transpose
([
2
,
3
,
0
,
1
]).
transpose
([
1
,
0
,
2
,
3
])
masks
=
masks
.
reshape
((
-
1
,
masks
.
shape
[
-
1
])
)
.
transpose
([
1
,
0
])
assert
all
(
masks
.
sum
((
0
,
2
,
3
))
==
np
.
array
([
90
.
,
90
.
,
0
.
,
90.
,
90
.
]))
assert
all
(
masks
.
sum
((
1
))
==
np
.
array
([
45
.
,
45
.
,
45
.
,
45.
,
0.
,
0.
,
45.
,
45.
,
45.
,
45
.
]))
pruner
.
update_epoch
(
1
)
pruner
.
update_epoch
(
1
)
model
.
layers
[
2
].
set_weights
([
weights
[
0
],
weights
[
1
].
numpy
()])
model
.
layers
[
2
].
set_weights
([
weights
[
0
],
weights
[
1
].
numpy
()])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
]).
numpy
()
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
]).
numpy
()
masks
=
masks
.
transpose
([
2
,
3
,
0
,
1
]).
transpose
([
1
,
0
,
2
,
3
])
masks
=
masks
.
reshape
((
-
1
,
masks
.
shape
[
-
1
])).
transpose
([
1
,
0
])
assert
all
(
masks
.
sum
((
1
))
==
np
.
array
([
45.
,
45.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
45.
,
45.
]))
def
test_torch_l1filter_pruner
(
self
):
"""
Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
PRUNING FILTERS FOR EFFICIENT CONVNETS,
https://arxiv.org/abs/1608.08710
So if sparsity is 0.2, the expected masks should mask out filter 0, this can be verified through:
`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
If sparsity is 0.6, the expected masks should mask out filter 0,1,2, this can be verified through:
`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
"""
w
=
np
.
array
([
np
.
zeros
((
3
,
3
,
3
)),
np
.
ones
((
3
,
3
,
3
)),
np
.
ones
((
3
,
3
,
3
))
*
2
,
np
.
ones
((
3
,
3
,
3
))
*
3
,
np
.
ones
((
3
,
3
,
3
))
*
4
])
model
=
TorchModel
()
config_list
=
[{
'sparsity'
:
0.2
,
'op_names'
:
[
'conv1'
]},
{
'sparsity'
:
0.6
,
'op_names'
:
[
'conv2'
]}]
pruner
=
torch_compressor
.
L1FilterPruner
(
model
,
config_list
)
model
.
conv1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv1'
,
model
.
conv1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv2'
,
model
.
conv2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
1
])
assert
all
(
torch
.
sum
(
mask1
,
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
0.
,
27.
,
27.
,
27.
,
27.
]))
assert
all
(
torch
.
sum
(
mask2
,
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
27.
,
27.
]))
def
test_torch_slim_pruner
(
self
):
"""
Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/pdf/1708.06519.pdf
So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
`all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`
If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
`all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
"""
w
=
np
.
array
([
0
,
1
,
2
,
3
,
4
])
model
=
TorchModel
()
config_list
=
[{
'sparsity'
:
0.2
,
'op_types'
:
[
'BatchNorm2d'
]}]
model
.
bn1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
bn2
.
weight
.
data
=
torch
.
tensor
(
-
w
).
float
()
pruner
=
torch_compressor
.
SlimPruner
(
model
,
config_list
)
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn1'
,
model
.
bn1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn2'
,
model
.
bn2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
])
assert
all
(
mask1
.
numpy
()
==
np
.
array
([
0.
,
1.
,
1.
,
1.
,
1.
]))
assert
all
(
mask2
.
numpy
()
==
np
.
array
([
0.
,
1.
,
1.
,
1.
,
1.
]))
config_list
=
[{
'sparsity'
:
0.6
,
'op_types'
:
[
'BatchNorm2d'
]}]
model
.
bn1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
bn2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
pruner
=
torch_compressor
.
SlimPruner
(
model
,
config_list
)
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn1'
,
model
.
bn1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn2'
,
model
.
bn2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
])
assert
all
(
mask1
.
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
1.
,
1.
]))
assert
all
(
mask2
.
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
1.
,
1.
]))
def
test_torch_QAT_quantizer
(
self
):
model
=
TorchModel
()
config_list
=
[{
'quant_types'
:
[
'weight'
],
'quant_bits'
:
8
,
'op_types'
:[
'Conv2d'
,
'Linear'
]
},
{
'quant_types'
:
[
'output'
],
'quant_bits'
:
8
,
'quant_start_step'
:
0
,
'op_types'
:[
'ReLU'
]
}]
model
.
relu
=
torch
.
nn
.
ReLU
()
quantizer
=
torch_compressor
.
QAT_Quantizer
(
model
,
config_list
)
quantizer
.
compress
()
# test quantize
# range not including 0
eps
=
1e-7
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
quantize_weight
=
quantizer
.
quantize_weight
(
weight
,
config_list
[
0
],
model
.
conv2
)
assert
math
.
isclose
(
model
.
conv2
.
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
zero_point
==
0
# range including 0
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
quantize_weight
=
quantizer
.
quantize_weight
(
weight
,
config_list
[
0
],
model
.
conv2
)
assert
math
.
isclose
(
model
.
conv2
.
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
zero_point
in
(
42
,
43
)
assert
all
(
masks
.
sum
((
0
,
2
,
3
))
==
np
.
array
([
90.
,
0.
,
0.
,
0.
,
90.
]))
# test ema
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
out
=
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
tracked_min_biased
,
0
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
tracked_max_biased
,
0.002
,
abs_tol
=
eps
)
quantizer
.
step
()
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
out
=
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
tracked_min_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
tracked_max_biased
,
0.00998
,
abs_tol
=
eps
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
src/sdk/pynni/tests/test_msg_dispatcher.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
#
# Licensed under the MIT license.
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
json
import
json
from
io
import
BytesIO
from
io
import
BytesIO
...
...
src/sdk/pynni/tests/test_protocol.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
#
# Licensed under the MIT license.
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
nni.protocol
import
nni.protocol
from
nni.protocol
import
CommandType
,
send
,
receive
from
nni.protocol
import
CommandType
,
send
,
receive
...
...
src/sdk/pynni/tests/test_smartparam.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
#
# Licensed under the MIT license.
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
import
os
...
...
src/sdk/pynni/tests/test_trial.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
#
# Licensed under the MIT license.
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
nni
import
nni
import
nni.platform.test
as
test_platform
import
nni.platform.test
as
test_platform
...
@@ -93,4 +76,4 @@ class TrialTestCase(TestCase):
...
@@ -93,4 +76,4 @@ class TrialTestCase(TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
\ No newline at end of file
src/sdk/pynni/tests/test_utils.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
#
# Licensed under the MIT license.
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
...
@@ -102,4 +85,4 @@ class UtilsTestCase(TestCase):
...
@@ -102,4 +85,4 @@ class UtilsTestCase(TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
\ No newline at end of file
src/webui/package.json
View file @
cd3a912a
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
"copy-to-clipboard"
:
"^3.0.8"
,
"copy-to-clipboard"
:
"^3.0.8"
,
"css-loader"
:
"0.28.7"
,
"css-loader"
:
"0.28.7"
,
"dotenv"
:
"^8.0.0"
,
"dotenv"
:
"^8.0.0"
,
"echarts"
:
"^4.
1
.0"
,
"echarts"
:
"^4.
5
.0"
,
"echarts-for-react"
:
"^2.0.14"
,
"echarts-for-react"
:
"^2.0.14"
,
"file-loader"
:
"^4.1.0"
,
"file-loader"
:
"^4.1.0"
,
"fork-ts-checker-webpack-plugin"
:
"^1.5.0"
,
"fork-ts-checker-webpack-plugin"
:
"^1.5.0"
,
...
...
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
View file @
cd3a912a
...
@@ -3,7 +3,7 @@ import { Switch } from 'antd';
...
@@ -3,7 +3,7 @@ import { Switch } from 'antd';
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
EXPERIMENT
,
TRIALS
}
from
'
../../static/datamodel
'
;
import
{
EXPERIMENT
,
TRIALS
}
from
'
../../static/datamodel
'
;
import
{
Trial
}
from
'
../../static/model/trial
'
;
import
{
Trial
}
from
'
../../static/model/trial
'
;
import
{
TooltipForAccuracy
}
from
'
../../static/interface
'
;
import
{
TooltipForAccuracy
,
EventMap
}
from
'
../../static/interface
'
;
require
(
'
echarts/lib/chart/scatter
'
);
require
(
'
echarts/lib/chart/scatter
'
);
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/title
'
);
require
(
'
echarts/lib/component/title
'
);
...
@@ -16,12 +16,18 @@ interface DefaultPointProps {
...
@@ -16,12 +16,18 @@ interface DefaultPointProps {
interface
DefaultPointState
{
interface
DefaultPointState
{
bestCurveEnabled
:
boolean
;
bestCurveEnabled
:
boolean
;
startY
:
number
;
// dataZoomY
endY
:
number
;
}
}
class
DefaultPoint
extends
React
.
Component
<
DefaultPointProps
,
DefaultPointState
>
{
class
DefaultPoint
extends
React
.
Component
<
DefaultPointProps
,
DefaultPointState
>
{
constructor
(
props
:
DefaultPointProps
)
{
constructor
(
props
:
DefaultPointProps
)
{
super
(
props
);
super
(
props
);
this
.
state
=
{
bestCurveEnabled
:
false
};
this
.
state
=
{
bestCurveEnabled
:
false
,
startY
:
0
,
// dataZoomY
endY
:
100
,
};
}
}
loadDefault
=
(
checked
:
boolean
)
=>
{
loadDefault
=
(
checked
:
boolean
)
=>
{
...
@@ -35,6 +41,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -35,6 +41,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
render
()
{
render
()
{
const
graph
=
this
.
generateGraph
();
const
graph
=
this
.
generateGraph
();
const
accNodata
=
(
graph
===
EmptyGraph
?
'
No data
'
:
''
);
const
accNodata
=
(
graph
===
EmptyGraph
?
'
No data
'
:
''
);
const
onEvents
=
{
'
dataZoom
'
:
this
.
metricDataZoom
};
return
(
return
(
<
div
>
<
div
>
...
@@ -53,6 +60,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -53,6 +60,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
}
}
}
theme
=
"my_theme"
theme
=
"my_theme"
notMerge
=
{
true
}
// update now
notMerge
=
{
true
}
// update now
onEvents
=
{
onEvents
}
/>
/>
<
div
className
=
"showMess"
>
{
accNodata
}
</
div
>
<
div
className
=
"showMess"
>
{
accNodata
}
</
div
>
</
div
>
</
div
>
...
@@ -64,14 +72,66 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -64,14 +72,66 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
if
(
trials
.
length
===
0
)
{
if
(
trials
.
length
===
0
)
{
return
EmptyGraph
;
return
EmptyGraph
;
}
}
const
graph
=
generateGraphConfig
(
trials
[
trials
.
length
-
1
].
sequenceId
);
const
graph
=
this
.
generateGraphConfig
(
trials
[
trials
.
length
-
1
].
sequenceId
);
if
(
this
.
state
.
bestCurveEnabled
)
{
if
(
this
.
state
.
bestCurveEnabled
)
{
(
graph
as
any
).
series
=
[
generateBestCurveSeries
(
trials
),
generateScatterSeries
(
trials
)
];
(
graph
as
any
).
series
=
[
generateBestCurveSeries
(
trials
),
generateScatterSeries
(
trials
)];
}
else
{
}
else
{
(
graph
as
any
).
series
=
[
generateScatterSeries
(
trials
)
];
(
graph
as
any
).
series
=
[
generateScatterSeries
(
trials
)];
}
}
return
graph
;
return
graph
;
}
}
private
generateGraphConfig
(
maxSequenceId
:
number
)
{
const
{
startY
,
endY
}
=
this
.
state
;
return
{
grid
:
{
left
:
'
8%
'
,
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
=>
(
[(
data
.
data
[
0
]
<
maxSequenceId
?
point
[
0
]
:
(
point
[
0
]
-
300
)),
80
]
),
formatter
:
(
data
:
TooltipForAccuracy
)
=>
(
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Default metric:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters: <pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre></div>
'
+
'
</div>
'
),
},
dataZoom
:
[
{
id
:
'
dataZoomY
'
,
type
:
'
inside
'
,
yAxisIndex
:
[
0
],
filterMode
:
'
empty
'
,
start
:
startY
,
end
:
endY
}
],
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
,
},
series
:
undefined
,
};
}
private
metricDataZoom
=
(
e
:
EventMap
)
=>
{
if
(
e
.
batch
!==
undefined
)
{
this
.
setState
(()
=>
({
startY
:
(
e
.
batch
[
0
].
start
!==
null
?
e
.
batch
[
0
].
start
:
0
),
endY
:
(
e
.
batch
[
0
].
end
!==
null
?
e
.
batch
[
0
].
end
:
100
)
}));
}
}
}
}
const
EmptyGraph
=
{
const
EmptyGraph
=
{
...
@@ -85,41 +145,10 @@ const EmptyGraph = {
...
@@ -85,41 +145,10 @@ const EmptyGraph = {
yAxis
:
{
yAxis
:
{
name
:
'
Default metric
'
,
name
:
'
Default metric
'
,
type
:
'
value
'
,
type
:
'
value
'
,
scale
:
true
,
}
}
};
};
function
generateGraphConfig
(
maxSequenceId
:
number
)
{
return
{
grid
:
{
left
:
'
8%
'
,
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
=>
(
[
(
data
.
data
[
0
]
<
maxSequenceId
?
point
[
0
]
:
(
point
[
0
]
-
300
)),
80
]
),
formatter
:
(
data
:
TooltipForAccuracy
)
=>
(
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Default metric:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters: <pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre></div>
'
+
'
</div>
'
),
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
,
},
series
:
undefined
,
};
}
function
generateScatterSeries
(
trials
:
Trial
[])
{
function
generateScatterSeries
(
trials
:
Trial
[])
{
const
data
=
trials
.
map
(
trial
=>
[
const
data
=
trials
.
map
(
trial
=>
[
trial
.
sequenceId
,
trial
.
sequenceId
,
...
@@ -135,17 +164,17 @@ function generateScatterSeries(trials: Trial[]) {
...
@@ -135,17 +164,17 @@ function generateScatterSeries(trials: Trial[]) {
function
generateBestCurveSeries
(
trials
:
Trial
[])
{
function
generateBestCurveSeries
(
trials
:
Trial
[])
{
let
best
=
trials
[
0
];
let
best
=
trials
[
0
];
const
data
=
[[
best
.
sequenceId
,
best
.
accuracy
,
best
.
description
.
parameters
]];
const
data
=
[[
best
.
sequenceId
,
best
.
accuracy
,
best
.
description
.
parameters
]];
for
(
let
i
=
1
;
i
<
trials
.
length
;
i
++
)
{
for
(
let
i
=
1
;
i
<
trials
.
length
;
i
++
)
{
const
trial
=
trials
[
i
];
const
trial
=
trials
[
i
];
const
delta
=
trial
.
accuracy
!
-
best
.
accuracy
!
;
const
delta
=
trial
.
accuracy
!
-
best
.
accuracy
!
;
const
better
=
(
EXPERIMENT
.
optimizeMode
===
'
minimize
'
)
?
(
delta
<
0
)
:
(
delta
>
0
);
const
better
=
(
EXPERIMENT
.
optimizeMode
===
'
minimize
'
)
?
(
delta
<
0
)
:
(
delta
>
0
);
if
(
better
)
{
if
(
better
)
{
data
.
push
([
trial
.
sequenceId
,
trial
.
accuracy
,
trial
.
description
.
parameters
]);
data
.
push
([
trial
.
sequenceId
,
trial
.
accuracy
,
trial
.
description
.
parameters
]);
best
=
trial
;
best
=
trial
;
}
else
{
}
else
{
data
.
push
([
trial
.
sequenceId
,
best
.
accuracy
,
trial
.
description
.
parameters
]);
data
.
push
([
trial
.
sequenceId
,
best
.
accuracy
,
trial
.
description
.
parameters
]);
}
}
}
}
...
...
src/webui/src/components/trial-detail/Duration.tsx
View file @
cd3a912a
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
TableObj
}
from
'
src/static/interface
'
;
import
{
TableObj
,
EventMap
}
from
'
src/static/interface
'
;
import
{
filterDuration
}
from
'
src/static/function
'
;
import
{
filterDuration
}
from
'
src/static/function
'
;
require
(
'
echarts/lib/chart/bar
'
);
require
(
'
echarts/lib/chart/bar
'
);
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/tooltip
'
);
...
@@ -17,7 +17,8 @@ interface DurationProps {
...
@@ -17,7 +17,8 @@ interface DurationProps {
}
}
interface
DurationState
{
interface
DurationState
{
durationSource
:
{};
startDuration
:
number
;
// for record data zoom
endDuration
:
number
;
}
}
class
Duration
extends
React
.
Component
<
DurationProps
,
DurationState
>
{
class
Duration
extends
React
.
Component
<
DurationProps
,
DurationState
>
{
...
@@ -26,63 +27,13 @@ class Duration extends React.Component<DurationProps, DurationState> {
...
@@ -26,63 +27,13 @@ class Duration extends React.Component<DurationProps, DurationState> {
super
(
props
);
super
(
props
);
this
.
state
=
{
this
.
state
=
{
durationSource
:
this
.
initDuration
(
this
.
props
.
source
),
startDuration
:
0
,
// for record data zoom
};
endDuration
:
100
,
}
initDuration
=
(
source
:
Array
<
TableObj
>
)
=>
{
const
trialId
:
Array
<
string
>
=
[];
const
trialTime
:
Array
<
number
>
=
[];
const
trialJobs
=
source
.
filter
(
filterDuration
);
Object
.
keys
(
trialJobs
).
map
(
item
=>
{
const
temp
=
trialJobs
[
item
];
trialId
.
push
(
temp
.
sequenceId
);
trialTime
.
push
(
temp
.
duration
);
});
return
{
tooltip
:
{
trigger
:
'
axis
'
,
axisPointer
:
{
type
:
'
shadow
'
}
},
grid
:
{
bottom
:
'
3%
'
,
containLabel
:
true
,
left
:
'
1%
'
,
right
:
'
4%
'
},
dataZoom
:
[{
type
:
'
slider
'
,
name
:
'
trial
'
,
filterMode
:
'
filter
'
,
yAxisIndex
:
0
,
orient
:
'
vertical
'
},
{
type
:
'
slider
'
,
name
:
'
trial
'
,
filterMode
:
'
filter
'
,
xAxisIndex
:
0
}],
xAxis
:
{
name
:
'
Time
'
,
type
:
'
value
'
,
},
yAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
data
:
trialId
},
series
:
[{
type
:
'
bar
'
,
data
:
trialTime
}]
};
};
}
}
getOption
=
(
dataObj
:
Runtrial
)
=>
{
getOption
=
(
dataObj
:
Runtrial
)
=>
{
const
{
startDuration
,
endDuration
}
=
this
.
state
;
return
{
return
{
tooltip
:
{
tooltip
:
{
trigger
:
'
axis
'
,
trigger
:
'
axis
'
,
...
@@ -96,19 +47,16 @@ class Duration extends React.Component<DurationProps, DurationState> {
...
@@ -96,19 +47,16 @@ class Duration extends React.Component<DurationProps, DurationState> {
left
:
'
1%
'
,
left
:
'
1%
'
,
right
:
'
4%
'
right
:
'
4%
'
},
},
dataZoom
:
[
dataZoom
:
[{
{
type
:
'
slider
'
,
id
:
'
dataZoomY
'
,
name
:
'
trial
'
,
type
:
'
inside
'
,
filterMode
:
'
filter
'
,
yAxisIndex
:
[
0
],
yAxisIndex
:
0
,
filterMode
:
'
empty
'
,
orient
:
'
vertical
'
start
:
startDuration
,
},
{
end
:
endDuration
type
:
'
slider
'
,
},
name
:
'
trial
'
,
],
filterMode
:
'
filter
'
,
xAxisIndex
:
0
}],
xAxis
:
{
xAxis
:
{
name
:
'
Time
'
,
name
:
'
Time
'
,
type
:
'
value
'
,
type
:
'
value
'
,
...
@@ -140,21 +88,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
...
@@ -140,21 +88,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
trialId
:
trialId
,
trialId
:
trialId
,
trialTime
:
trialTime
trialTime
:
trialTime
});
});
this
.
setState
({
return
this
.
getOption
(
trialRun
[
0
]);
durationSource
:
this
.
getOption
(
trialRun
[
0
])
});
}
componentDidMount
()
{
const
{
source
}
=
this
.
props
;
this
.
drawDurationGraph
(
source
);
}
componentWillReceiveProps
(
nextProps
:
DurationProps
)
{
const
{
whichGraph
,
source
}
=
nextProps
;
if
(
whichGraph
===
'
3
'
)
{
this
.
drawDurationGraph
(
source
);
}
}
}
shouldComponentUpdate
(
nextProps
:
DurationProps
,
nextState
:
DurationState
)
{
shouldComponentUpdate
(
nextProps
:
DurationProps
,
nextState
:
DurationState
)
{
...
@@ -183,18 +117,31 @@ class Duration extends React.Component<DurationProps, DurationState> {
...
@@ -183,18 +117,31 @@ class Duration extends React.Component<DurationProps, DurationState> {
}
}
render
()
{
render
()
{
const
{
durationSource
}
=
this
.
state
;
const
{
source
}
=
this
.
props
;
const
graph
=
this
.
drawDurationGraph
(
source
);
const
onEvents
=
{
'
dataZoom
'
:
this
.
durationDataZoom
};
return
(
return
(
<
div
>
<
div
>
<
ReactEcharts
<
ReactEcharts
option
=
{
durationSource
}
option
=
{
graph
}
style
=
{
{
width
:
'
95%
'
,
height
:
412
,
margin
:
'
0 auto
'
}
}
style
=
{
{
width
:
'
95%
'
,
height
:
412
,
margin
:
'
0 auto
'
}
}
theme
=
"my_theme"
theme
=
"my_theme"
notMerge
=
{
true
}
// update now
notMerge
=
{
true
}
// update now
onEvents
=
{
onEvents
}
/>
/>
</
div
>
</
div
>
);
);
}
}
private
durationDataZoom
=
(
e
:
EventMap
)
=>
{
if
(
e
.
batch
!==
undefined
)
{
this
.
setState
(()
=>
({
startDuration
:
(
e
.
batch
[
0
].
start
!==
null
?
e
.
batch
[
0
].
start
:
0
),
endDuration
:
(
e
.
batch
[
0
].
end
!==
null
?
e
.
batch
[
0
].
end
:
100
)
}));
}
}
}
}
export
default
Duration
;
export
default
Duration
;
src/webui/src/components/trial-detail/Intermediate.tsx
View file @
cd3a912a
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
{
Row
,
Button
,
Switch
}
from
'
antd
'
;
import
{
Row
,
Button
,
Switch
}
from
'
antd
'
;
import
{
TooltipForIntermediate
,
TableObj
,
Intermedia
}
from
'
../../static/interface
'
;
import
{
TooltipForIntermediate
,
TableObj
,
Intermedia
,
EventMap
}
from
'
../../static/interface
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/title
'
);
require
(
'
echarts/lib/component/title
'
);
...
@@ -14,6 +14,8 @@ interface IntermediateState {
...
@@ -14,6 +14,8 @@ interface IntermediateState {
isFilter
:
boolean
;
isFilter
:
boolean
;
length
:
number
;
length
:
number
;
clickCounts
:
number
;
// user filter intermediate click confirm btn's counts
clickCounts
:
number
;
// user filter intermediate click confirm btn's counts
startMediaY
:
number
;
endMediaY
:
number
;
}
}
interface
IntermediateProps
{
interface
IntermediateProps
{
...
@@ -38,7 +40,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -38,7 +40,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
isLoadconfirmBtn
:
false
,
isLoadconfirmBtn
:
false
,
isFilter
:
false
,
isFilter
:
false
,
length
:
100000
,
length
:
100000
,
clickCounts
:
0
clickCounts
:
0
,
startMediaY
:
0
,
endMediaY
:
100
};
};
}
}
...
@@ -48,6 +52,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -48,6 +52,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
length
:
source
.
length
,
length
:
source
.
length
,
detailSource
:
source
detailSource
:
source
});
});
const
{
startMediaY
,
endMediaY
}
=
this
.
state
;
const
trialIntermediate
:
Array
<
Intermedia
>
=
[];
const
trialIntermediate
:
Array
<
Intermedia
>
=
[];
Object
.
keys
(
source
).
map
(
item
=>
{
Object
.
keys
(
source
).
map
(
item
=>
{
const
temp
=
source
[
item
];
const
temp
=
source
[
item
];
...
@@ -113,6 +118,16 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -113,6 +118,16 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
type
:
'
value
'
,
type
:
'
value
'
,
name
:
'
Metric
'
name
:
'
Metric
'
},
},
dataZoom
:
[
{
id
:
'
dataZoomY
'
,
type
:
'
inside
'
,
yAxisIndex
:
[
0
],
filterMode
:
'
empty
'
,
start
:
startMediaY
,
end
:
endMediaY
}
],
series
:
trialIntermediate
series
:
trialIntermediate
};
};
this
.
setState
({
this
.
setState
({
...
@@ -258,6 +273,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -258,6 +273,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
render
()
{
render
()
{
const
{
interSource
,
isLoadconfirmBtn
,
isFilter
}
=
this
.
state
;
const
{
interSource
,
isLoadconfirmBtn
,
isFilter
}
=
this
.
state
;
const
IntermediateEvents
=
{
'
dataZoom
'
:
this
.
intermediateDataZoom
};
return
(
return
(
<
div
>
<
div
>
{
/* style in para.scss */
}
{
/* style in para.scss */
}
...
@@ -265,7 +281,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -265,7 +281,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
{
{
isFilter
isFilter
?
?
<
span
style
=
{
{
marginRight
:
15
}
}
>
<
span
style
=
{
{
marginRight
:
15
}
}
>
<
span
className
=
"filter-x"
>
# Intermediate result
</
span
>
<
span
className
=
"filter-x"
>
# Intermediate result
</
span
>
<
input
<
input
// placeholder="point"
// placeholder="point"
...
@@ -306,12 +322,22 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -306,12 +322,22 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
option
=
{
interSource
}
option
=
{
interSource
}
style
=
{
{
width
:
'
100%
'
,
height
:
418
,
margin
:
'
0 auto
'
}
}
style
=
{
{
width
:
'
100%
'
,
height
:
418
,
margin
:
'
0 auto
'
}
}
notMerge
=
{
true
}
// update now
notMerge
=
{
true
}
// update now
onEvents
=
{
IntermediateEvents
}
/>
/>
<
div
className
=
"yAxis"
>
# Intermediate result
</
div
>
<
div
className
=
"yAxis"
>
# Intermediate result
</
div
>
</
Row
>
</
Row
>
</
div
>
</
div
>
);
);
}
}
private
intermediateDataZoom
=
(
e
:
EventMap
)
=>
{
if
(
e
.
batch
!==
undefined
)
{
this
.
setState
(()
=>
({
startMediaY
:
(
e
.
batch
[
0
].
start
!==
null
?
e
.
batch
[
0
].
start
:
0
),
endMediaY
:
(
e
.
batch
[
0
].
end
!==
null
?
e
.
batch
[
0
].
end
:
100
)
}));
}
}
}
}
export
default
Intermediate
;
export
default
Intermediate
;
src/webui/src/components/trial-detail/Para.tsx
View file @
cd3a912a
...
@@ -275,6 +275,7 @@ class Para extends React.Component<ParaProps, ParaState> {
...
@@ -275,6 +275,7 @@ class Para extends React.Component<ParaProps, ParaState> {
parallelAxis
.
push
({
parallelAxis
.
push
({
dim
:
i
,
dim
:
i
,
name
:
'
default metric
'
,
name
:
'
default metric
'
,
scale
:
true
,
nameTextStyle
:
{
nameTextStyle
:
{
fontWeight
:
700
fontWeight
:
700
}
}
...
...
src/webui/src/components/trial-detail/TableList.tsx
View file @
cd3a912a
...
@@ -586,17 +586,16 @@ const AccuracyColumnConfig: ColumnProps<TableRecord> = {
...
@@ -586,17 +586,16 @@ const AccuracyColumnConfig: ColumnProps<TableRecord> = {
dataIndex
:
'
accuracy
'
,
dataIndex
:
'
accuracy
'
,
width
:
120
,
width
:
120
,
sorter
:
(
a
,
b
,
sortOrder
)
=>
{
sorter
:
(
a
,
b
,
sortOrder
)
=>
{
if
(
a
.
accuracy
===
undefined
)
{
if
(
a
.
latestAccuracy
===
undefined
)
{
return
sortOrder
===
'
ascend
'
?
-
1
:
1
;
}
else
if
(
b
.
accuracy
===
undefined
)
{
return
sortOrder
===
'
ascend
'
?
1
:
-
1
;
return
sortOrder
===
'
ascend
'
?
1
:
-
1
;
}
else
if
(
b
.
latestAccuracy
===
undefined
)
{
return
sortOrder
===
'
ascend
'
?
-
1
:
1
;
}
else
{
}
else
{
return
a
.
a
ccuracy
-
b
.
a
ccuracy
;
return
a
.
latestA
ccuracy
-
b
.
latestA
ccuracy
;
}
}
},
},
render
:
(
text
,
record
)
=>
(
render
:
(
text
,
record
)
=>
(
// TODO: is this needed?
<
div
>
{
record
.
formattedLatestAccuracy
}
</
div
>
<
div
>
{
record
.
latestAccuracy
}
</
div
>
)
)
};
};
...
...
src/webui/src/static/function.ts
View file @
cd3a912a
...
@@ -186,5 +186,5 @@ function formatAccuracy(accuracy: number): string {
...
@@ -186,5 +186,5 @@ function formatAccuracy(accuracy: number): string {
export
{
export
{
convertTime
,
convertDuration
,
getFinalResult
,
getFinal
,
downFile
,
convertTime
,
convertDuration
,
getFinalResult
,
getFinal
,
downFile
,
intermediateGraphOption
,
killJob
,
filterByStatus
,
filterDuration
,
intermediateGraphOption
,
killJob
,
filterByStatus
,
filterDuration
,
formatAccuracy
,
formatTimestamp
,
metricAccuracy
,
formatAccuracy
,
formatTimestamp
,
metricAccuracy
};
};
src/webui/src/static/interface.ts
View file @
cd3a912a
...
@@ -24,7 +24,8 @@ interface TableRecord {
...
@@ -24,7 +24,8 @@ interface TableRecord {
status
:
string
;
status
:
string
;
intermediateCount
:
number
;
intermediateCount
:
number
;
accuracy
?:
number
;
accuracy
?:
number
;
latestAccuracy
:
string
;
// formatted string
latestAccuracy
:
number
|
undefined
;
formattedLatestAccuracy
:
string
;
// format (LATEST/FINAL)
}
}
interface
SearchSpace
{
interface
SearchSpace
{
...
@@ -81,6 +82,7 @@ interface Dimobj {
...
@@ -81,6 +82,7 @@ interface Dimobj {
axisLabel
?:
object
;
axisLabel
?:
object
;
axisLine
?:
object
;
axisLine
?:
object
;
nameTextStyle
?:
object
;
nameTextStyle
?:
object
;
scale
?:
boolean
;
}
}
interface
ParaObj
{
interface
ParaObj
{
...
@@ -179,9 +181,13 @@ interface NNIManagerStatus {
...
@@ -179,9 +181,13 @@ interface NNIManagerStatus {
errors
:
string
[];
errors
:
string
[];
}
}
interface
EventMap
{
[
key
:
string
]:
()
=>
void
;
}
export
{
export
{
TableObj
,
TableRecord
,
Parameters
,
ExperimentProfile
,
AccurPoint
,
TableObj
,
TableRecord
,
Parameters
,
ExperimentProfile
,
AccurPoint
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalType
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalType
,
TooltipForIntermediate
,
SearchSpace
,
Intermedia
,
MetricDataRecord
,
TrialJobInfo
,
TooltipForIntermediate
,
SearchSpace
,
Intermedia
,
MetricDataRecord
,
TrialJobInfo
,
NNIManagerStatus
,
NNIManagerStatus
,
EventMap
};
};
src/webui/src/static/model/trial.ts
View file @
cd3a912a
...
@@ -46,6 +46,22 @@ class Trial implements TableObj {
...
@@ -46,6 +46,22 @@ class Trial implements TableObj {
return
this
.
metricsInitialized
&&
this
.
finalAcc
!==
undefined
&&
!
isNaN
(
this
.
finalAcc
);
return
this
.
metricsInitialized
&&
this
.
finalAcc
!==
undefined
&&
!
isNaN
(
this
.
finalAcc
);
}
}
get
latestAccuracy
():
number
|
undefined
{
if
(
this
.
accuracy
!==
undefined
)
{
return
this
.
accuracy
;
}
else
if
(
this
.
intermediates
.
length
>
0
)
{
// TODO: support intermeidate result is dict
const
temp
=
this
.
intermediates
[
this
.
intermediates
.
length
-
1
];
if
(
temp
!==
undefined
)
{
return
JSON
.
parse
(
temp
.
data
);
}
else
{
return
undefined
;
}
}
else
{
return
undefined
;
}
}
/* table obj start */
/* table obj start */
get
tableRecord
():
TableRecord
{
get
tableRecord
():
TableRecord
{
...
@@ -62,7 +78,8 @@ class Trial implements TableObj {
...
@@ -62,7 +78,8 @@ class Trial implements TableObj {
status
:
this
.
info
.
status
,
status
:
this
.
info
.
status
,
intermediateCount
:
this
.
intermediates
.
length
,
intermediateCount
:
this
.
intermediates
.
length
,
accuracy
:
this
.
finalAcc
,
accuracy
:
this
.
finalAcc
,
latestAccuracy
:
this
.
formatLatestAccuracy
(),
latestAccuracy
:
this
.
latestAccuracy
,
formattedLatestAccuracy
:
this
.
formatLatestAccuracy
(),
};
};
}
}
...
...
src/webui/yarn.lock
View file @
cd3a912a
This source diff could not be displayed because it is too large. You can
view the blob
instead.
test/async_sharing_test/main.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
"""
Test code for weight sharing
Test code for weight sharing
need NFS setup and mounted as `/mnt/nfs/nni`
need NFS setup and mounted as `/mnt/nfs/nni`
...
...
test/async_sharing_test/simple_tuner.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
"""
SimpleTuner for Weight Sharing
SimpleTuner for Weight Sharing
"""
"""
...
...
test/cli_test.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation
# Copyright (c) Microsoft Corporation.
# All rights reserved.
# Licensed under the MIT license.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import
sys
import
sys
import
time
import
time
...
@@ -26,7 +9,7 @@ from utils import GREEN, RED, CLEAR, setup_experiment
...
@@ -26,7 +9,7 @@ from utils import GREEN, RED, CLEAR, setup_experiment
def
test_nni_cli
():
def
test_nni_cli
():
import
nnicli
as
nc
import
nnicli
as
nc
config_file
=
'config_test/examples/mnist.test.yml'
config_file
=
'config_test/examples/mnist
-tfv1
.test.yml'
try
:
try
:
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
...
...
test/config_test.py
View file @
cd3a912a
# Copyright (c) Microsoft Corporation
# Copyright (c) Microsoft Corporation.
# All rights reserved.
# Licensed under the MIT license.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import
os
import
os
import
argparse
import
argparse
...
...
Prev
1
…
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment