Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c785655e
Unverified
Commit
c785655e
authored
Oct 21, 2019
by
SparkSnail
Committed by
GitHub
Oct 21, 2019
Browse files
Merge pull request #207 from microsoft/master
merge master
parents
9fae194a
d6b61e2f
Changes
158
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1489 additions
and
1277 deletions
+1489
-1277
src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py
src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py
+598
-0
src/sdk/pynni/nni/ppo_tuner/requirements.txt
src/sdk/pynni/nni/ppo_tuner/requirements.txt
+3
-0
src/sdk/pynni/nni/ppo_tuner/util.py
src/sdk/pynni/nni/ppo_tuner/util.py
+266
-0
src/sdk/pynni/nni/trial.py
src/sdk/pynni/nni/trial.py
+2
-1
src/sdk/pynni/nni/tuner.py
src/sdk/pynni/nni/tuner.py
+1
-2
src/sdk/pynni/tests/test_assessor.py
src/sdk/pynni/tests/test_assessor.py
+4
-2
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+116
-0
src/sdk/pynni/tests/test_tuner.py
src/sdk/pynni/tests/test_tuner.py
+7
-7
src/webui/src/App.tsx
src/webui/src/App.tsx
+87
-80
src/webui/src/components/Modal/Compare.tsx
src/webui/src/components/Modal/Compare.tsx
+10
-7
src/webui/src/components/Modal/ExperimentDrawer.tsx
src/webui/src/components/Modal/ExperimentDrawer.tsx
+1
-1
src/webui/src/components/Modal/LogDrawer.tsx
src/webui/src/components/Modal/LogDrawer.tsx
+2
-2
src/webui/src/components/Overview.tsx
src/webui/src/components/Overview.tsx
+96
-456
src/webui/src/components/SlideBar.tsx
src/webui/src/components/SlideBar.tsx
+12
-36
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+50
-339
src/webui/src/components/overview/BasicInfo.tsx
src/webui/src/components/overview/BasicInfo.tsx
+12
-35
src/webui/src/components/overview/NumInput.tsx
src/webui/src/components/overview/NumInput.tsx
+85
-0
src/webui/src/components/overview/Progress.tsx
src/webui/src/components/overview/Progress.tsx
+66
-181
src/webui/src/components/overview/SuccessTable.tsx
src/webui/src/components/overview/SuccessTable.tsx
+54
-102
src/webui/src/components/overview/TrialProfile.tsx
src/webui/src/components/overview/TrialProfile.tsx
+17
-26
No files found.
src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py
0 → 100644
View file @
c785655e
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
ppo_tuner.py including:
class PPOTuner
"""
import
os
import
copy
import
logging
import
numpy
as
np
import
json_tricks
from
gym
import
spaces
import
nni
from
nni.tuner
import
Tuner
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
from
.model
import
Model
from
.util
import
set_global_seeds
from
.policy
import
build_lstm_policy
logger
=
logging
.
getLogger
(
'ppo_tuner_AutoML'
)
def
constfn
(
val
):
"""wrap as function"""
def
f
(
_
):
return
val
return
f
class
ModelConfig
:
"""
Configurations of the PPO model
"""
def
__init__
(
self
):
self
.
observation_space
=
None
self
.
action_space
=
None
self
.
num_envs
=
0
self
.
nsteps
=
0
self
.
ent_coef
=
0.0
self
.
lr
=
3e-4
self
.
vf_coef
=
0.5
self
.
max_grad_norm
=
0.5
self
.
gamma
=
0.99
self
.
lam
=
0.95
self
.
cliprange
=
0.2
self
.
embedding_size
=
None
# the embedding is for each action
self
.
noptepochs
=
4
# number of training epochs per update
self
.
total_timesteps
=
5000
# number of timesteps (i.e. number of actions taken in the environment)
self
.
nminibatches
=
4
# number of training minibatches per update. For recurrent policies,
# should be smaller or equal than number of environments run in parallel.
class
TrialsInfo
:
"""
Informations of each trial from one model inference
"""
def
__init__
(
self
,
obs
,
actions
,
values
,
neglogpacs
,
dones
,
last_value
,
inf_batch_size
):
self
.
iter
=
0
self
.
obs
=
obs
self
.
actions
=
actions
self
.
values
=
values
self
.
neglogpacs
=
neglogpacs
self
.
dones
=
dones
self
.
last_value
=
last_value
self
.
rewards
=
None
self
.
returns
=
None
self
.
inf_batch_size
=
inf_batch_size
#self.states = None
def
get_next
(
self
):
"""
get actions of the next trial
"""
if
self
.
iter
>=
self
.
inf_batch_size
:
return
None
,
None
actions
=
[]
for
step
in
self
.
actions
:
actions
.
append
(
step
[
self
.
iter
])
self
.
iter
+=
1
return
self
.
iter
-
1
,
actions
def
update_rewards
(
self
,
rewards
,
returns
):
"""
after the trial is finished, reward and return of this trial is updated
"""
self
.
rewards
=
rewards
self
.
returns
=
returns
def
convert_shape
(
self
):
"""
convert shape
"""
def
sf01
(
arr
):
"""
swap and then flatten axes 0 and 1
"""
s
=
arr
.
shape
return
arr
.
swapaxes
(
0
,
1
).
reshape
(
s
[
0
]
*
s
[
1
],
*
s
[
2
:])
self
.
obs
=
sf01
(
self
.
obs
)
self
.
returns
=
sf01
(
self
.
returns
)
self
.
dones
=
sf01
(
self
.
dones
)
self
.
actions
=
sf01
(
self
.
actions
)
self
.
values
=
sf01
(
self
.
values
)
self
.
neglogpacs
=
sf01
(
self
.
neglogpacs
)
class
PPOModel
:
"""
PPO Model
"""
def
__init__
(
self
,
model_config
,
mask
):
self
.
model_config
=
model_config
self
.
states
=
None
# initial state of lstm in policy/value network
self
.
nupdates
=
None
# the number of func train is invoked, used to tune lr and cliprange
self
.
cur_update
=
1
# record the current update
self
.
np_mask
=
mask
# record the mask of each action within one trial
set_global_seeds
(
None
)
assert
isinstance
(
self
.
model_config
.
lr
,
float
)
self
.
lr
=
constfn
(
self
.
model_config
.
lr
)
assert
isinstance
(
self
.
model_config
.
cliprange
,
float
)
self
.
cliprange
=
constfn
(
self
.
model_config
.
cliprange
)
# build lstm policy network, value share the same network
policy
=
build_lstm_policy
(
model_config
)
# Get the nb of env
nenvs
=
model_config
.
num_envs
# Calculate the batch_size
self
.
nbatch
=
nbatch
=
nenvs
*
model_config
.
nsteps
# num of record per update
nbatch_train
=
nbatch
//
model_config
.
nminibatches
# get batch size
# self.nupdates is used to tune lr and cliprange
self
.
nupdates
=
self
.
model_config
.
total_timesteps
//
self
.
nbatch
# Instantiate the model object (that creates act_model and train_model)
self
.
model
=
Model
(
policy
=
policy
,
nbatch_act
=
nenvs
,
nbatch_train
=
nbatch_train
,
nsteps
=
model_config
.
nsteps
,
ent_coef
=
model_config
.
ent_coef
,
vf_coef
=
model_config
.
vf_coef
,
max_grad_norm
=
model_config
.
max_grad_norm
,
np_mask
=
self
.
np_mask
)
self
.
states
=
self
.
model
.
initial_state
logger
.
info
(
'=== finished PPOModel initialization'
)
def
inference
(
self
,
num
):
"""
generate actions along with related info from policy network.
observation is the action of the last step.
Parameters:
----------
num: the number of trials to generate
"""
# Here, we init the lists that will contain the mb of experiences
mb_obs
,
mb_actions
,
mb_values
,
mb_dones
,
mb_neglogpacs
=
[],
[],
[],
[],
[]
# initial observation
# use the (n+1)th embedding to represent the first step action
first_step_ob
=
self
.
model_config
.
action_space
.
n
obs
=
[
first_step_ob
for
_
in
range
(
num
)]
dones
=
[
True
for
_
in
range
(
num
)]
states
=
self
.
states
# For n in range number of steps
for
cur_step
in
range
(
self
.
model_config
.
nsteps
):
# Given observations, get action value and neglopacs
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
actions
,
values
,
states
,
neglogpacs
=
self
.
model
.
step
(
cur_step
,
obs
,
S
=
states
,
M
=
dones
)
mb_obs
.
append
(
obs
.
copy
())
mb_actions
.
append
(
actions
)
mb_values
.
append
(
values
)
mb_neglogpacs
.
append
(
neglogpacs
)
mb_dones
.
append
(
dones
)
# Take actions in env and look the results
# Infos contains a ton of useful informations
obs
[:]
=
actions
if
cur_step
==
self
.
model_config
.
nsteps
-
1
:
dones
=
[
True
for
_
in
range
(
num
)]
else
:
dones
=
[
False
for
_
in
range
(
num
)]
#batch of steps to batch of rollouts
np_obs
=
np
.
asarray
(
obs
)
mb_obs
=
np
.
asarray
(
mb_obs
,
dtype
=
np_obs
.
dtype
)
mb_actions
=
np
.
asarray
(
mb_actions
)
mb_values
=
np
.
asarray
(
mb_values
,
dtype
=
np
.
float32
)
mb_neglogpacs
=
np
.
asarray
(
mb_neglogpacs
,
dtype
=
np
.
float32
)
mb_dones
=
np
.
asarray
(
mb_dones
,
dtype
=
np
.
bool
)
last_values
=
self
.
model
.
value
(
np_obs
,
S
=
states
,
M
=
dones
)
return
mb_obs
,
mb_actions
,
mb_values
,
mb_neglogpacs
,
mb_dones
,
last_values
def
compute_rewards
(
self
,
trials_info
,
trials_result
):
"""
compute the rewards of the trials in trials_info based on trials_result,
and update the rewards in trials_info
Parameters:
----------
trials_info: info of the generated trials
trials_result: final results (e.g., acc) of the generated trials
"""
mb_rewards
=
np
.
asarray
([
trials_result
for
_
in
trials_info
.
actions
],
dtype
=
np
.
float32
)
# discount/bootstrap off value fn
mb_returns
=
np
.
zeros_like
(
mb_rewards
)
mb_advs
=
np
.
zeros_like
(
mb_rewards
)
lastgaelam
=
0
last_dones
=
np
.
asarray
([
True
for
_
in
trials_result
],
dtype
=
np
.
bool
)
# ugly
for
t
in
reversed
(
range
(
self
.
model_config
.
nsteps
)):
if
t
==
self
.
model_config
.
nsteps
-
1
:
nextnonterminal
=
1.0
-
last_dones
nextvalues
=
trials_info
.
last_value
else
:
nextnonterminal
=
1.0
-
trials_info
.
dones
[
t
+
1
]
nextvalues
=
trials_info
.
values
[
t
+
1
]
delta
=
mb_rewards
[
t
]
+
self
.
model_config
.
gamma
*
nextvalues
*
nextnonterminal
-
trials_info
.
values
[
t
]
mb_advs
[
t
]
=
lastgaelam
=
delta
+
self
.
model_config
.
gamma
*
self
.
model_config
.
lam
*
nextnonterminal
*
lastgaelam
mb_returns
=
mb_advs
+
trials_info
.
values
trials_info
.
update_rewards
(
mb_rewards
,
mb_returns
)
trials_info
.
convert_shape
()
def
train
(
self
,
trials_info
,
nenvs
):
"""
train the policy/value network using trials_info
Parameters:
----------
trials_info: complete info of the generated trials from the previous inference
nenvs: the batch size of the (previous) inference
"""
# keep frac decay for future optimization
if
self
.
cur_update
<=
self
.
nupdates
:
frac
=
1.0
-
(
self
.
cur_update
-
1.0
)
/
self
.
nupdates
else
:
logger
.
warning
(
'current update (self.cur_update) %d has exceeded total updates (self.nupdates) %d'
,
self
.
cur_update
,
self
.
nupdates
)
frac
=
1.0
-
(
self
.
nupdates
-
1.0
)
/
self
.
nupdates
lrnow
=
self
.
lr
(
frac
)
cliprangenow
=
self
.
cliprange
(
frac
)
self
.
cur_update
+=
1
states
=
self
.
states
assert
states
is
not
None
# recurrent version
assert
nenvs
%
self
.
model_config
.
nminibatches
==
0
envsperbatch
=
nenvs
//
self
.
model_config
.
nminibatches
envinds
=
np
.
arange
(
nenvs
)
flatinds
=
np
.
arange
(
nenvs
*
self
.
model_config
.
nsteps
).
reshape
(
nenvs
,
self
.
model_config
.
nsteps
)
for
_
in
range
(
self
.
model_config
.
noptepochs
):
np
.
random
.
shuffle
(
envinds
)
for
start
in
range
(
0
,
nenvs
,
envsperbatch
):
end
=
start
+
envsperbatch
mbenvinds
=
envinds
[
start
:
end
]
mbflatinds
=
flatinds
[
mbenvinds
].
ravel
()
slices
=
(
arr
[
mbflatinds
]
for
arr
in
(
trials_info
.
obs
,
trials_info
.
returns
,
trials_info
.
dones
,
trials_info
.
actions
,
trials_info
.
values
,
trials_info
.
neglogpacs
))
mbstates
=
states
[
mbenvinds
]
self
.
model
.
train
(
lrnow
,
cliprangenow
,
*
slices
,
mbstates
)
class
PPOTuner
(
Tuner
):
"""
PPOTuner
"""
def
__init__
(
self
,
optimize_mode
,
trials_per_update
=
20
,
epochs_per_update
=
4
,
minibatch_size
=
4
,
ent_coef
=
0.0
,
lr
=
3e-4
,
vf_coef
=
0.5
,
max_grad_norm
=
0.5
,
gamma
=
0.99
,
lam
=
0.95
,
cliprange
=
0.2
):
"""
initialization, PPO model is not initialized here as search space is not received yet.
Parameters:
----------
optimize_mode: maximize or minimize
trials_per_update: number of trials to have for each model update
epochs_per_update: number of epochs to run for each model update
minibatch_size: minibatch size (number of trials) for the update
ent_coef: policy entropy coefficient in the optimization objective
lr: learning rate of the model (lstm network), constant
vf_coef: value function loss coefficient in the optimization objective
max_grad_norm: gradient norm clipping coefficient
gamma: discounting factor
lam: advantage estimation discounting factor (lambda in the paper)
cliprange: cliprange in the PPO algorithm, constant
"""
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
model_config
=
ModelConfig
()
self
.
model
=
None
self
.
search_space
=
None
self
.
running_trials
=
{}
# key: parameter_id, value: actions/states/etc.
self
.
inf_batch_size
=
trials_per_update
# number of trials to generate in one inference
self
.
first_inf
=
True
# indicate whether it is the first time to inference new trials
self
.
trials_result
=
[
None
for
_
in
range
(
self
.
inf_batch_size
)]
# results of finished trials
self
.
credit
=
0
# record the unsatisfied trial requests
self
.
param_ids
=
[]
self
.
finished_trials
=
0
self
.
chosen_arch_template
=
{}
self
.
actions_spaces
=
None
self
.
actions_to_config
=
None
self
.
full_act_space
=
None
self
.
trials_info
=
None
self
.
all_trials
=
{}
# used to dedup the same trial, key: config, value: final result
self
.
model_config
.
num_envs
=
self
.
inf_batch_size
self
.
model_config
.
noptepochs
=
epochs_per_update
self
.
model_config
.
nminibatches
=
minibatch_size
self
.
send_trial_callback
=
None
logger
.
info
(
'=== finished PPOTuner initialization'
)
def
_process_one_nas_space
(
self
,
block_name
,
block_space
):
"""
process nas space to determine observation space and action space
Parameters:
----------
block_name: the name of the mutable block
block_space: search space of this mutable block
Returns:
----------
actions_spaces: list of the space of each action
actions_to_config: the mapping from action to generated configuration
"""
actions_spaces
=
[]
actions_to_config
=
[]
block_arch_temp
=
{}
for
l_name
,
layer
in
block_space
.
items
():
chosen_layer_temp
=
{}
if
len
(
layer
[
'layer_choice'
])
>
1
:
actions_spaces
.
append
(
layer
[
'layer_choice'
])
actions_to_config
.
append
((
block_name
,
l_name
,
'chosen_layer'
))
chosen_layer_temp
[
'chosen_layer'
]
=
None
else
:
assert
len
(
layer
[
'layer_choice'
])
==
1
chosen_layer_temp
[
'chosen_layer'
]
=
layer
[
'layer_choice'
][
0
]
if
layer
[
'optional_input_size'
]
not
in
[
0
,
1
,
[
0
,
1
]]:
raise
ValueError
(
'Optional_input_size can only be 0, 1, or [0, 1], but the pecified one is %s'
%
(
layer
[
'optional_input_size'
]))
if
isinstance
(
layer
[
'optional_input_size'
],
list
):
actions_spaces
.
append
([
"None"
,
*
layer
[
'optional_inputs'
]])
actions_to_config
.
append
((
block_name
,
l_name
,
'chosen_inputs'
))
chosen_layer_temp
[
'chosen_inputs'
]
=
None
elif
layer
[
'optional_input_size'
]
==
1
:
actions_spaces
.
append
(
layer
[
'optional_inputs'
])
actions_to_config
.
append
((
block_name
,
l_name
,
'chosen_inputs'
))
chosen_layer_temp
[
'chosen_inputs'
]
=
None
elif
layer
[
'optional_input_size'
]
==
0
:
chosen_layer_temp
[
'chosen_inputs'
]
=
[]
else
:
raise
ValueError
(
'invalid type and value of optional_input_size'
)
block_arch_temp
[
l_name
]
=
chosen_layer_temp
self
.
chosen_arch_template
[
block_name
]
=
block_arch_temp
return
actions_spaces
,
actions_to_config
def
_process_nas_space
(
self
,
search_space
):
"""
process nas search space to get action/observation space
"""
actions_spaces
=
[]
actions_to_config
=
[]
for
b_name
,
block
in
search_space
.
items
():
if
block
[
'_type'
]
!=
'mutable_layer'
:
raise
ValueError
(
'PPOTuner only accept mutable_layer type in search space, but the current one is %s'
%
(
block
[
'_type'
]))
block
=
block
[
'_value'
]
act
,
act_map
=
self
.
_process_one_nas_space
(
b_name
,
block
)
actions_spaces
.
extend
(
act
)
actions_to_config
.
extend
(
act_map
)
# calculate observation space
dedup
=
{}
for
step
in
actions_spaces
:
for
action
in
step
:
dedup
[
action
]
=
1
full_act_space
=
[
act
for
act
,
_
in
dedup
.
items
()]
assert
len
(
full_act_space
)
==
len
(
dedup
)
observation_space
=
len
(
full_act_space
)
nsteps
=
len
(
actions_spaces
)
return
actions_spaces
,
actions_to_config
,
full_act_space
,
observation_space
,
nsteps
def
_generate_action_mask
(
self
):
"""
different step could have different action space. to deal with this case, we merge all the
possible actions into one action space, and use mask to indicate available actions for each step
"""
two_masks
=
[]
mask
=
[]
for
acts
in
self
.
actions_spaces
:
one_mask
=
[
0
for
_
in
range
(
len
(
self
.
full_act_space
))]
for
act
in
acts
:
idx
=
self
.
full_act_space
.
index
(
act
)
one_mask
[
idx
]
=
1
mask
.
append
(
one_mask
)
two_masks
.
append
(
mask
)
mask
=
[]
for
acts
in
self
.
actions_spaces
:
one_mask
=
[
-
np
.
inf
for
_
in
range
(
len
(
self
.
full_act_space
))]
for
act
in
acts
:
idx
=
self
.
full_act_space
.
index
(
act
)
one_mask
[
idx
]
=
0
mask
.
append
(
one_mask
)
two_masks
.
append
(
mask
)
return
np
.
asarray
(
two_masks
,
dtype
=
np
.
float32
)
def
update_search_space
(
self
,
search_space
):
"""
get search space, currently the space only includes that for NAS
Parameters:
----------
search_space: search space for NAS
Returns:
-------
no return
"""
logger
.
info
(
'=== update search space %s'
,
search_space
)
assert
self
.
search_space
is
None
self
.
search_space
=
search_space
assert
self
.
model_config
.
observation_space
is
None
assert
self
.
model_config
.
action_space
is
None
self
.
actions_spaces
,
self
.
actions_to_config
,
self
.
full_act_space
,
obs_space
,
nsteps
=
self
.
_process_nas_space
(
search_space
)
self
.
model_config
.
observation_space
=
spaces
.
Discrete
(
obs_space
)
self
.
model_config
.
action_space
=
spaces
.
Discrete
(
obs_space
)
self
.
model_config
.
nsteps
=
nsteps
# generate mask in numpy
mask
=
self
.
_generate_action_mask
()
assert
self
.
model
is
None
self
.
model
=
PPOModel
(
self
.
model_config
,
mask
)
def
_actions_to_config
(
self
,
actions
):
"""
given actions, to generate the corresponding trial configuration
"""
chosen_arch
=
copy
.
deepcopy
(
self
.
chosen_arch_template
)
for
cnt
,
act
in
enumerate
(
actions
):
act_name
=
self
.
full_act_space
[
act
]
(
block_name
,
layer_name
,
key
)
=
self
.
actions_to_config
[
cnt
]
if
key
==
'chosen_inputs'
:
if
act_name
==
'None'
:
chosen_arch
[
block_name
][
layer_name
][
key
]
=
[]
else
:
chosen_arch
[
block_name
][
layer_name
][
key
]
=
[
act_name
]
elif
key
==
'chosen_layer'
:
chosen_arch
[
block_name
][
layer_name
][
key
]
=
act_name
else
:
raise
ValueError
(
'unrecognized key: {0}'
.
format
(
key
))
return
chosen_arch
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""
Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
"""
result
=
[]
self
.
send_trial_callback
=
kwargs
[
'st_callback'
]
for
parameter_id
in
parameter_id_list
:
had_exception
=
False
try
:
logger
.
debug
(
"generating param for %s"
,
parameter_id
)
res
=
self
.
generate_parameters
(
parameter_id
,
**
kwargs
)
except
nni
.
NoMoreTrialError
:
had_exception
=
True
if
not
had_exception
:
result
.
append
(
res
)
return
result
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
generate parameters, if no trial configration for now, self.credit plus 1 to send the config later
"""
if
self
.
first_inf
:
self
.
trials_result
=
[
None
for
_
in
range
(
self
.
inf_batch_size
)]
mb_obs
,
mb_actions
,
mb_values
,
mb_neglogpacs
,
mb_dones
,
last_values
=
self
.
model
.
inference
(
self
.
inf_batch_size
)
self
.
trials_info
=
TrialsInfo
(
mb_obs
,
mb_actions
,
mb_values
,
mb_neglogpacs
,
mb_dones
,
last_values
,
self
.
inf_batch_size
)
self
.
first_inf
=
False
trial_info_idx
,
actions
=
self
.
trials_info
.
get_next
()
if
trial_info_idx
is
None
:
self
.
credit
+=
1
self
.
param_ids
.
append
(
parameter_id
)
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
self
.
running_trials
[
parameter_id
]
=
trial_info_idx
new_config
=
self
.
_actions_to_config
(
actions
)
return
new_config
def
_next_round_inference
(
self
):
"""
"""
self
.
finished_trials
=
0
self
.
model
.
compute_rewards
(
self
.
trials_info
,
self
.
trials_result
)
self
.
model
.
train
(
self
.
trials_info
,
self
.
inf_batch_size
)
self
.
running_trials
=
{}
# generate new trials
self
.
trials_result
=
[
None
for
_
in
range
(
self
.
inf_batch_size
)]
mb_obs
,
mb_actions
,
mb_values
,
mb_neglogpacs
,
mb_dones
,
last_values
=
self
.
model
.
inference
(
self
.
inf_batch_size
)
self
.
trials_info
=
TrialsInfo
(
mb_obs
,
mb_actions
,
mb_values
,
mb_neglogpacs
,
mb_dones
,
last_values
,
self
.
inf_batch_size
)
# check credit and submit new trials
for
_
in
range
(
self
.
credit
):
trial_info_idx
,
actions
=
self
.
trials_info
.
get_next
()
if
trial_info_idx
is
None
:
logger
.
warning
(
'No enough trial config, trials_per_update is suggested to be larger than trialConcurrency'
)
break
assert
self
.
param_ids
param_id
=
self
.
param_ids
.
pop
()
self
.
running_trials
[
param_id
]
=
trial_info_idx
new_config
=
self
.
_actions_to_config
(
actions
)
self
.
send_trial_callback
(
param_id
,
new_config
)
self
.
credit
-=
1
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
receive trial's result. if the number of finished trials equals self.inf_batch_size, start the next update to
train the model
"""
trial_info_idx
=
self
.
running_trials
.
pop
(
parameter_id
,
None
)
assert
trial_info_idx
is
not
None
value
=
extract_scalar_reward
(
value
)
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
:
value
=
-
value
self
.
trials_result
[
trial_info_idx
]
=
value
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
inf_batch_size
:
self
.
_next_round_inference
()
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""
to deal with trial failure
"""
if
not
success
:
if
parameter_id
not
in
self
.
running_trials
:
logger
.
warning
(
'The trial is failed, but self.running_trial does not have this trial'
)
return
trial_info_idx
=
self
.
running_trials
.
pop
(
parameter_id
,
None
)
assert
trial_info_idx
is
not
None
# use mean of finished trials as the result of this failed trial
values
=
[
val
for
val
in
self
.
trials_result
if
val
is
not
None
]
logger
.
warning
(
'zql values: {0}'
.
format
(
values
))
self
.
trials_result
[
trial_info_idx
]
=
(
sum
(
values
)
/
len
(
values
))
if
len
(
values
)
>
0
else
0
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
inf_batch_size
:
self
.
_next_round_inference
()
def
import_data
(
self
,
data
):
"""
Import additional data for tuning
Parameters
----------
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
logger
.
warning
(
'PPOTuner cannot leverage imported data.'
)
src/sdk/pynni/nni/ppo_tuner/requirements.txt
0 → 100644
View file @
c785655e
enum34
gym
tensorflow
\ No newline at end of file
src/sdk/pynni/nni/ppo_tuner/util.py
0 → 100644
View file @
c785655e
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
util functions
"""
import
os
import
random
import
multiprocessing
import
numpy
as
np
import
tensorflow
as
tf
from
gym.spaces
import
Discrete
,
Box
,
MultiDiscrete
def
set_global_seeds
(
i
):
"""set global seeds"""
rank
=
0
myseed
=
i
+
1000
*
rank
if
i
is
not
None
else
None
tf
.
set_random_seed
(
myseed
)
np
.
random
.
seed
(
myseed
)
random
.
seed
(
myseed
)
def
batch_to_seq
(
h
,
nbatch
,
nsteps
,
flat
=
False
):
"""convert from batch to sequence"""
if
flat
:
h
=
tf
.
reshape
(
h
,
[
nbatch
,
nsteps
])
else
:
h
=
tf
.
reshape
(
h
,
[
nbatch
,
nsteps
,
-
1
])
return
[
tf
.
squeeze
(
v
,
[
1
])
for
v
in
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
nsteps
,
value
=
h
)]
def
seq_to_batch
(
h
,
flat
=
False
):
"""convert from sequence to batch"""
shape
=
h
[
0
].
get_shape
().
as_list
()
if
not
flat
:
assert
len
(
shape
)
>
1
nh
=
h
[
0
].
get_shape
()[
-
1
].
value
return
tf
.
reshape
(
tf
.
concat
(
axis
=
1
,
values
=
h
),
[
-
1
,
nh
])
else
:
return
tf
.
reshape
(
tf
.
stack
(
values
=
h
,
axis
=
1
),
[
-
1
])
def
lstm
(
xs
,
ms
,
s
,
scope
,
nh
,
init_scale
=
1.0
):
"""lstm cell"""
nbatch
,
nin
=
[
v
.
value
for
v
in
xs
[
0
].
get_shape
()]
with
tf
.
variable_scope
(
scope
):
wx
=
tf
.
get_variable
(
"wx"
,
[
nin
,
nh
*
4
],
initializer
=
ortho_init
(
init_scale
))
wh
=
tf
.
get_variable
(
"wh"
,
[
nh
,
nh
*
4
],
initializer
=
ortho_init
(
init_scale
))
b
=
tf
.
get_variable
(
"b"
,
[
nh
*
4
],
initializer
=
tf
.
constant_initializer
(
0.0
))
c
,
h
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
2
,
value
=
s
)
for
idx
,
(
x
,
m
)
in
enumerate
(
zip
(
xs
,
ms
)):
c
=
c
*
(
1
-
m
)
h
=
h
*
(
1
-
m
)
z
=
tf
.
matmul
(
x
,
wx
)
+
tf
.
matmul
(
h
,
wh
)
+
b
i
,
f
,
o
,
u
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
4
,
value
=
z
)
i
=
tf
.
nn
.
sigmoid
(
i
)
f
=
tf
.
nn
.
sigmoid
(
f
)
o
=
tf
.
nn
.
sigmoid
(
o
)
u
=
tf
.
tanh
(
u
)
c
=
f
*
c
+
i
*
u
h
=
o
*
tf
.
tanh
(
c
)
xs
[
idx
]
=
h
s
=
tf
.
concat
(
axis
=
1
,
values
=
[
c
,
h
])
return
xs
,
s
def
lstm_model
(
nlstm
=
128
,
layer_norm
=
False
):
"""
Builds LSTM (Long-Short Term Memory) network to be used in a policy.
Note that the resulting function returns not only the output of the LSTM
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
with auxiliary tensors to be set as policy attributes.
Specifically,
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
initial_state is a numpy array containing initial lstm state (usually zeros)
state is the output LSTM state (to be fed into S at the next call)
An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
Parameters:
----------
nlstm: int LSTM hidden state size
layer_norm: bool if True, layer-normalized version of LSTM is used
Returns:
-------
function that builds LSTM with a given input tensor / placeholder
"""
def
network_fn
(
X
,
nenv
=
1
,
obs_size
=-
1
):
with
tf
.
variable_scope
(
"emb"
,
reuse
=
tf
.
AUTO_REUSE
):
w_emb
=
tf
.
get_variable
(
"w_emb"
,
[
obs_size
+
1
,
32
])
X
=
tf
.
nn
.
embedding_lookup
(
w_emb
,
X
)
nbatch
=
X
.
shape
[
0
]
nsteps
=
nbatch
//
nenv
h
=
tf
.
layers
.
flatten
(
X
)
M
=
tf
.
placeholder
(
tf
.
float32
,
[
nbatch
])
#mask (done t-1)
S
=
tf
.
placeholder
(
tf
.
float32
,
[
nenv
,
2
*
nlstm
])
#states
xs
=
batch_to_seq
(
h
,
nenv
,
nsteps
)
ms
=
batch_to_seq
(
M
,
nenv
,
nsteps
)
assert
not
layer_norm
h5
,
snew
=
lstm
(
xs
,
ms
,
S
,
scope
=
'lstm'
,
nh
=
nlstm
)
h
=
seq_to_batch
(
h5
)
initial_state
=
np
.
zeros
(
S
.
shape
.
as_list
(),
dtype
=
float
)
return
h
,
{
'S'
:
S
,
'M'
:
M
,
'state'
:
snew
,
'initial_state'
:
initial_state
}
return
network_fn
def
ortho_init
(
scale
=
1.0
):
"""init approach"""
def
_ortho_init
(
shape
,
dtype
,
partition_info
=
None
):
#lasagne ortho init for tf
shape
=
tuple
(
shape
)
if
len
(
shape
)
==
2
:
flat_shape
=
shape
elif
len
(
shape
)
==
4
:
# assumes NHWC
flat_shape
=
(
np
.
prod
(
shape
[:
-
1
]),
shape
[
-
1
])
else
:
raise
NotImplementedError
a
=
np
.
random
.
normal
(
0.0
,
1.0
,
flat_shape
)
u
,
_
,
v
=
np
.
linalg
.
svd
(
a
,
full_matrices
=
False
)
q
=
u
if
u
.
shape
==
flat_shape
else
v
# pick the one with the correct shape
q
=
q
.
reshape
(
shape
)
return
(
scale
*
q
[:
shape
[
0
],
:
shape
[
1
]]).
astype
(
np
.
float32
)
return
_ortho_init
def
fc
(
x
,
scope
,
nh
,
*
,
init_scale
=
1.0
,
init_bias
=
0.0
):
"""fully connected op"""
with
tf
.
variable_scope
(
scope
):
nin
=
x
.
get_shape
()[
1
].
value
w
=
tf
.
get_variable
(
"w"
,
[
nin
,
nh
],
initializer
=
ortho_init
(
init_scale
))
b
=
tf
.
get_variable
(
"b"
,
[
nh
],
initializer
=
tf
.
constant_initializer
(
init_bias
))
return
tf
.
matmul
(
x
,
w
)
+
b
def
_check_shape
(
placeholder_shape
,
data_shape
):
"""
check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension)
"""
return
True
# ================================================================
# Shape adjustment for feeding into tf placeholders
# ================================================================
def
adjust_shape
(
placeholder
,
data
):
"""
adjust shape of the data to the shape of the placeholder if possible.
If shape is incompatible, AssertionError is thrown
Parameters:
placeholder: tensorflow input placeholder
data: input data to be (potentially) reshaped to be fed into placeholder
Returns:
reshaped data
"""
if
not
isinstance
(
data
,
np
.
ndarray
)
and
not
isinstance
(
data
,
list
):
return
data
if
isinstance
(
data
,
list
):
data
=
np
.
array
(
data
)
placeholder_shape
=
[
x
or
-
1
for
x
in
placeholder
.
shape
.
as_list
()]
assert
_check_shape
(
placeholder_shape
,
data
.
shape
),
\
'Shape of data {} is not compatible with shape of the placeholder {}'
.
format
(
data
.
shape
,
placeholder_shape
)
return
np
.
reshape
(
data
,
placeholder_shape
)
# ================================================================
# Global session
# ================================================================
def
get_session
(
config
=
None
):
"""Get default session or create one with a given config"""
sess
=
tf
.
get_default_session
()
if
sess
is
None
:
sess
=
make_session
(
config
=
config
,
make_default
=
True
)
return
sess
def
make_session
(
config
=
None
,
num_cpu
=
None
,
make_default
=
False
,
graph
=
None
):
"""Returns a session that will use <num_cpu> CPU's only"""
if
num_cpu
is
None
:
num_cpu
=
int
(
os
.
getenv
(
'RCALL_NUM_CPU'
,
multiprocessing
.
cpu_count
()))
if
config
is
None
:
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
,
inter_op_parallelism_threads
=
num_cpu
,
intra_op_parallelism_threads
=
num_cpu
)
config
.
gpu_options
.
allow_growth
=
True
if
make_default
:
return
tf
.
InteractiveSession
(
config
=
config
,
graph
=
graph
)
else
:
return
tf
.
Session
(
config
=
config
,
graph
=
graph
)
ALREADY_INITIALIZED
=
set
()
def
initialize
():
"""Initialize all the uninitialized variables in the global scope."""
new_variables
=
set
(
tf
.
global_variables
())
-
ALREADY_INITIALIZED
get_session
().
run
(
tf
.
variables_initializer
(
new_variables
))
ALREADY_INITIALIZED
.
update
(
new_variables
)
def
observation_placeholder
(
ob_space
,
batch_size
=
None
,
name
=
'Ob'
):
"""
Create placeholder to feed observations into of the size appropriate to the observation space
Parameters:
----------
ob_space: gym.Space observation space
batch_size: int size of the batch to be fed into input. Can be left None in most cases.
name: str name of the placeholder
Returns:
-------
tensorflow placeholder tensor
"""
assert
isinstance
(
ob_space
,
(
Discrete
,
Box
,
MultiDiscrete
)),
\
'Can only deal with Discrete and Box observation spaces for now'
dtype
=
ob_space
.
dtype
if
dtype
==
np
.
int8
:
dtype
=
np
.
uint8
return
tf
.
placeholder
(
shape
=
(
batch_size
,)
+
ob_space
.
shape
,
dtype
=
dtype
,
name
=
name
)
def
explained_variance
(
ypred
,
y
):
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
"""
assert
y
.
ndim
==
1
and
ypred
.
ndim
==
1
vary
=
np
.
var
(
y
)
return
np
.
nan
if
vary
==
0
else
1
-
np
.
var
(
y
-
ypred
)
/
vary
src/sdk/pynni/nni/trial.py
View file @
c785655e
...
...
@@ -43,7 +43,8 @@ _sequence_id = platform.get_sequence_id()
def
get_next_parameter
():
"""Returns a set of (hyper-)paremeters generated by Tuner."""
"""Returns a set of (hyper-)paremeters generated by Tuner.
Returns None if no more (hyper-)parameters can be generated by Tuner."""
global
_params
_params
=
platform
.
get_next_parameter
()
if
_params
is
None
:
...
...
src/sdk/pynni/nni/tuner.py
View file @
c785655e
...
...
@@ -17,11 +17,10 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
import
nni
from
.recoverable
import
Recoverable
_logger
=
logging
.
getLogger
(
__name__
)
...
...
src/sdk/pynni/tests/test_assessor.py
View file @
c785655e
...
...
@@ -28,9 +28,9 @@ from io import BytesIO
import
json
from
unittest
import
TestCase
,
main
_trials
=
[]
_end_trials
=
[]
_trials
=
[
]
_end_trials
=
[
]
class
NaiveAssessor
(
Assessor
):
def
assess_trial
(
self
,
trial_job_id
,
trial_history
):
...
...
@@ -47,12 +47,14 @@ class NaiveAssessor(Assessor):
_in_buf
=
BytesIO
()
_out_buf
=
BytesIO
()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_out_file
=
_in_buf
nni
.
protocol
.
_in_file
=
_out_buf
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
...
...
src/sdk/pynni/tests/test_compressor.py
0 → 100644
View file @
c785655e
from
unittest
import
TestCase
,
main
import
tensorflow
as
tf
import
torch
import
torch.nn.functional
as
F
import
nni.compression.tensorflow
as
tf_compressor
import
nni.compression.torch
as
torch_compressor
def
weight_variable
(
shape
):
return
tf
.
Variable
(
tf
.
truncated_normal
(
shape
,
stddev
=
0.1
))
def
bias_variable
(
shape
):
return
tf
.
Variable
(
tf
.
constant
(
0.1
,
shape
=
shape
))
def
conv2d
(
x_input
,
w_matrix
):
return
tf
.
nn
.
conv2d
(
x_input
,
w_matrix
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
'SAME'
)
def
max_pool
(
x_input
,
pool_size
):
size
=
[
1
,
pool_size
,
pool_size
,
1
]
return
tf
.
nn
.
max_pool
(
x_input
,
ksize
=
size
,
strides
=
size
,
padding
=
'SAME'
)
class
TfMnist
:
def
__init__
(
self
):
images
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
784
],
name
=
'input_x'
)
labels
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
10
],
name
=
'input_y'
)
keep_prob
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'keep_prob'
)
self
.
images
=
images
self
.
labels
=
labels
self
.
keep_prob
=
keep_prob
self
.
train_step
=
None
self
.
accuracy
=
None
self
.
w1
=
None
self
.
b1
=
None
self
.
fcw1
=
None
self
.
cross
=
None
with
tf
.
name_scope
(
'reshape'
):
x_image
=
tf
.
reshape
(
images
,
[
-
1
,
28
,
28
,
1
])
with
tf
.
name_scope
(
'conv1'
):
w_conv1
=
weight_variable
([
5
,
5
,
1
,
32
])
self
.
w1
=
w_conv1
b_conv1
=
bias_variable
([
32
])
self
.
b1
=
b_conv1
h_conv1
=
tf
.
nn
.
relu
(
conv2d
(
x_image
,
w_conv1
)
+
b_conv1
)
with
tf
.
name_scope
(
'pool1'
):
h_pool1
=
max_pool
(
h_conv1
,
2
)
with
tf
.
name_scope
(
'conv2'
):
w_conv2
=
weight_variable
([
5
,
5
,
32
,
64
])
b_conv2
=
bias_variable
([
64
])
h_conv2
=
tf
.
nn
.
relu
(
conv2d
(
h_pool1
,
w_conv2
)
+
b_conv2
)
with
tf
.
name_scope
(
'pool2'
):
h_pool2
=
max_pool
(
h_conv2
,
2
)
with
tf
.
name_scope
(
'fc1'
):
w_fc1
=
weight_variable
([
7
*
7
*
64
,
1024
])
self
.
fcw1
=
w_fc1
b_fc1
=
bias_variable
([
1024
])
h_pool2_flat
=
tf
.
reshape
(
h_pool2
,
[
-
1
,
7
*
7
*
64
])
h_fc1
=
tf
.
nn
.
relu
(
tf
.
matmul
(
h_pool2_flat
,
w_fc1
)
+
b_fc1
)
with
tf
.
name_scope
(
'dropout'
):
h_fc1_drop
=
tf
.
nn
.
dropout
(
h_fc1
,
0.5
)
with
tf
.
name_scope
(
'fc2'
):
w_fc2
=
weight_variable
([
1024
,
10
])
b_fc2
=
bias_variable
([
10
])
y_conv
=
tf
.
matmul
(
h_fc1_drop
,
w_fc2
)
+
b_fc2
with
tf
.
name_scope
(
'loss'
):
cross_entropy
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
y_conv
))
self
.
cross
=
cross_entropy
with
tf
.
name_scope
(
'adam_optimizer'
):
self
.
train_step
=
tf
.
train
.
AdamOptimizer
(
0.0001
).
minimize
(
cross_entropy
)
with
tf
.
name_scope
(
'accuracy'
):
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
y_conv
,
1
),
tf
.
argmax
(
labels
,
1
))
self
.
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
correct_prediction
,
tf
.
float32
))
class
TorchMnist
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
20
,
5
,
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
50
,
500
)
self
.
fc2
=
torch
.
nn
.
Linear
(
500
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
relu
(
self
.
conv2
(
x
))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
x
.
view
(
-
1
,
4
*
4
*
50
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
class
CompressorTestCase
(
TestCase
):
def
test_tf_pruner
(
self
):
model
=
TfMnist
()
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
'default'
}]
tf_compressor
.
LevelPruner
(
configure_list
).
compress_default_graph
()
def
test_tf_quantizer
(
self
):
model
=
TfMnist
()
tf_compressor
.
NaiveQuantizer
([{
'op_types'
:
'default'
}]).
compress_default_graph
()
def
test_torch_pruner
(
self
):
model
=
TorchMnist
()
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
'default'
}]
torch_compressor
.
LevelPruner
(
configure_list
).
compress
(
model
)
def
test_torch_quantizer
(
self
):
model
=
TorchMnist
()
torch_compressor
.
NaiveQuantizer
([{
'op_types'
:
'default'
}]).
compress
(
model
)
if
__name__
==
'__main__'
:
main
()
src/sdk/pynni/tests/test_tuner.py
View file @
c785655e
...
...
@@ -32,7 +32,7 @@ from unittest import TestCase, main
class
NaiveTuner
(
Tuner
):
def
__init__
(
self
):
self
.
param
=
0
self
.
trial_results
=
[
]
self
.
trial_results
=
[]
self
.
search_space
=
None
self
.
accept_customized_trials
()
...
...
@@ -57,12 +57,14 @@ class NaiveTuner(Tuner):
_in_buf
=
BytesIO
()
_out_buf
=
BytesIO
()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_out_file
=
_in_buf
nni
.
protocol
.
_in_file
=
_out_buf
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
...
...
@@ -70,7 +72,6 @@ def _restore_io():
nni
.
protocol
.
_out_file
=
_out_buf
class
TunerTestCase
(
TestCase
):
def
test_tuner
(
self
):
_reverse_io
()
# now we are sending to Tuner's incoming stream
...
...
@@ -94,21 +95,20 @@ class TunerTestCase(TestCase):
self
.
assertEqual
(
e
.
args
[
0
],
'Unsupported command: CommandType.KillTrialJob'
)
_reverse_io
()
# now we are receiving from Tuner's outgoing stream
self
.
_assert_params
(
0
,
2
,
[
],
None
)
self
.
_assert_params
(
1
,
4
,
[
],
None
)
self
.
_assert_params
(
0
,
2
,
[],
None
)
self
.
_assert_params
(
1
,
4
,
[],
None
)
command
,
data
=
receive
()
# this one is customized
data
=
json
.
loads
(
data
)
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
self
.
assertEqual
(
data
[
'parameter_id'
],
2
)
self
.
assertEqual
(
data
[
'parameter_source'
],
'customized'
)
self
.
assertEqual
(
data
[
'parameters'
],
{
'param'
:
-
1
})
self
.
assertEqual
(
data
[
'parameters'
],
{
'param'
:
-
1
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
self
.
assertEqual
(
len
(
_out_buf
.
read
()),
0
)
# no more commands
def
_assert_params
(
self
,
parameter_id
,
param
,
trial_results
,
search_space
):
command
,
data
=
receive
()
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
...
...
src/webui/src/App.tsx
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
{
Row
,
Col
}
from
'
antd
'
;
import
axios
from
'
axios
'
;
import
{
COLUMN
,
MANAGER_IP
}
from
'
./static/
const
'
;
import
{
COLUMN
}
from
'
./static/const
'
;
import
{
EXPERIMENT
,
TRIALS
}
from
'
./static/
datamodel
'
;
import
'
./App.css
'
;
import
SlideBar
from
'
./components/SlideBar
'
;
interface
AppState
{
interval
:
number
;
whichPageToFresh
:
string
;
columnList
:
Array
<
string
>
;
concurrency
:
number
;
interval
:
number
;
columnList
:
Array
<
string
>
;
experimentUpdateBroadcast
:
number
;
trialsUpdateBroadcast
:
number
;
}
class
App
extends
React
.
Component
<
{},
AppState
>
{
public
_isMounted
:
boolean
;
constructor
(
props
:
{})
{
super
(
props
);
this
.
state
=
{
interval
:
10
,
// sendons
whichPageToFresh
:
''
,
columnList
:
COLUMN
,
concurrency
:
1
};
}
private
timerId
:
number
|
null
;
changeInterval
=
(
interval
:
number
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
interval
:
interval
}));
constructor
(
props
:
{})
{
super
(
props
);
this
.
state
=
{
interval
:
10
,
// sendons
columnList
:
COLUMN
,
experimentUpdateBroadcast
:
0
,
trialsUpdateBroadcast
:
0
,
};
}
}
changeFresh
=
(
fresh
:
string
)
=>
{
// interval * 1000
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
whichPageToFresh
:
fresh
}));
async
componentDidMount
()
{
await
Promise
.
all
([
EXPERIMENT
.
init
(),
TRIALS
.
init
()
]);
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
}
}
changeColumn
=
(
columnList
:
Array
<
string
>
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
columnList
:
columnList
}));
changeInterval
=
(
interval
:
number
)
=>
{
this
.
setState
({
interval
:
interval
});
if
(
this
.
timerId
===
null
&&
interval
!==
0
)
{
window
.
setTimeout
(
this
.
refresh
);
}
else
if
(
this
.
timerId
!==
null
&&
interval
===
0
)
{
window
.
clearTimeout
(
this
.
timerId
);
}
}
}
changeConcurrency
=
(
val
:
number
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(
()
=>
({
concurrency
:
val
})
)
;
// TODO: use local storage
changeColumn
=
(
columnList
:
Array
<
string
>
)
=>
{
this
.
setState
(
{
columnList
:
columnList
});
}
}
getConcurrency
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
params
=
res
.
data
.
params
;
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
concurrency
:
params
.
trialConcurrency
}));
}
render
()
{
const
{
interval
,
columnList
,
experimentUpdateBroadcast
,
trialsUpdateBroadcast
}
=
this
.
state
;
if
(
experimentUpdateBroadcast
===
0
||
trialsUpdateBroadcast
===
0
)
{
return
null
;
// TODO: render a loading page
}
const
reactPropsChildren
=
React
.
Children
.
map
(
this
.
props
.
children
,
child
=>
React
.
cloneElement
(
// tslint:disable-next-line:no-any
child
as
React
.
ReactElement
<
any
>
,
{
interval
,
columnList
,
changeColumn
:
this
.
changeColumn
,
experimentUpdateBroadcast
,
trialsUpdateBroadcast
,
})
);
return
(
<
Row
className
=
"nni"
style
=
{
{
minHeight
:
window
.
innerHeight
}
}
>
<
Row
className
=
"header"
>
<
Col
span
=
{
1
}
/>
<
Col
className
=
"headerCon"
span
=
{
22
}
>
<
SlideBar
changeInterval
=
{
this
.
changeInterval
}
/>
</
Col
>
<
Col
span
=
{
1
}
/>
</
Row
>
<
Row
className
=
"contentBox"
>
<
Row
className
=
"content"
>
{
reactPropsChildren
}
</
Row
>
</
Row
>
</
Row
>
);
}
private
refresh
=
async
()
=>
{
const
[
experimentUpdated
,
trialsUpdated
]
=
await
Promise
.
all
([
EXPERIMENT
.
update
(),
TRIALS
.
update
()
]);
if
(
experimentUpdated
)
{
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
}
if
(
trialsUpdated
)
{
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
}
});
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
this
.
getConcurrency
();
}
if
([
'
DONE
'
,
'
ERROR
'
,
'
STOPPED
'
].
includes
(
EXPERIMENT
.
status
))
{
// experiment finished, refresh once more to ensure consistency
if
(
this
.
state
.
interval
>
0
)
{
this
.
setState
({
interval
:
0
});
this
.
lastRefresh
();
}
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
}
render
()
{
const
{
interval
,
whichPageToFresh
,
columnList
,
concurrency
}
=
this
.
state
;
const
reactPropsChildren
=
React
.
Children
.
map
(
this
.
props
.
children
,
child
=>
React
.
cloneElement
(
// tslint:disable-next-line:no-any
child
as
React
.
ReactElement
<
any
>
,
{
interval
,
whichPageToFresh
,
columnList
,
changeColumn
:
this
.
changeColumn
,
concurrency
,
changeConcurrency
:
this
.
changeConcurrency
})
);
return
(
<
Row
className
=
"nni"
style
=
{
{
minHeight
:
window
.
innerHeight
}
}
>
<
Row
className
=
"header"
>
<
Col
span
=
{
1
}
/>
<
Col
className
=
"headerCon"
span
=
{
22
}
>
<
SlideBar
changeInterval
=
{
this
.
changeInterval
}
changeFresh
=
{
this
.
changeFresh
}
/>
</
Col
>
<
Col
span
=
{
1
}
/>
</
Row
>
<
Row
className
=
"contentBox"
>
<
Row
className
=
"content"
>
{
reactPropsChildren
}
</
Row
>
</
Row
>
</
Row
>
);
}
}
else
if
(
this
.
state
.
interval
!==
0
)
{
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
}
}
private
async
lastRefresh
()
{
await
EXPERIMENT
.
update
();
await
TRIALS
.
update
(
true
);
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
}
}
export
default
App
;
src/webui/src/components/Modal/Compare.tsx
View file @
c785655e
...
...
@@ -2,12 +2,13 @@ import * as React from 'react';
import
{
Row
,
Modal
}
from
'
antd
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
import
{
TRIALS
}
from
'
../../static/datamodel
'
;
import
'
../../static/style/compare.scss
'
;
import
{
Table
Obj
,
Intermedia
,
TooltipForIntermediate
}
from
'
src/static/interface
'
;
import
{
Table
Record
,
Intermedia
,
TooltipForIntermediate
}
from
'
src/static/interface
'
;
// the modal of trial compare
interface
CompareProps
{
compareRows
:
Array
<
Table
Obj
>
;
compareRows
:
Array
<
Table
Record
>
;
visible
:
boolean
;
cancelFunc
:
()
=>
void
;
}
...
...
@@ -25,11 +26,12 @@ class Compare extends React.Component<CompareProps, {}> {
const
idsList
:
Array
<
string
>
=
[];
Object
.
keys
(
compareRows
).
map
(
item
=>
{
const
temp
=
compareRows
[
item
];
const
trial
=
TRIALS
.
getTrial
(
temp
.
id
);
trialIntermediate
.
push
({
name
:
temp
.
id
,
data
:
t
emp
.
description
.
intermediate
,
data
:
t
rial
.
description
.
intermediate
,
type
:
'
line
'
,
hyperPara
:
t
emp
.
description
.
parameters
hyperPara
:
t
rial
.
description
.
parameters
});
idsList
.
push
(
temp
.
id
);
});
...
...
@@ -105,11 +107,12 @@ class Compare extends React.Component<CompareProps, {}> {
// render table column ---
initColumn
=
()
=>
{
const
{
compareRows
}
=
this
.
props
;
const
idList
:
Array
<
string
>
=
[];
const
sequenceIdList
:
Array
<
number
>
=
[];
const
durationList
:
Array
<
number
>
=
[];
const
compareRows
=
this
.
props
.
compareRows
.
map
(
tableRecord
=>
TRIALS
.
getTrial
(
tableRecord
.
id
));
const
parameterList
:
Array
<
object
>
=
[];
let
parameterKeys
:
Array
<
string
>
=
[];
if
(
compareRows
.
length
!==
0
)
{
...
...
@@ -147,7 +150,7 @@ class Compare extends React.Component<CompareProps, {}> {
const
temp
=
compareRows
[
index
];
return
(
<
td
className
=
"value"
key
=
{
index
}
>
<
IntermediateVal
recor
d
=
{
temp
}
/>
<
IntermediateVal
trialI
d
=
{
temp
.
id
}
/>
</
td
>
);
})
}
...
...
@@ -206,7 +209,7 @@ class Compare extends React.Component<CompareProps, {}> {
>
<
Row
className
=
"compare-intermediate"
>
{
this
.
intermediate
()
}
<
Row
className
=
"compare-yAxis"
>
# Intermediate
</
Row
>
<
Row
className
=
"compare-yAxis"
>
# Intermediate
result
</
Row
>
</
Row
>
<
Row
>
{
this
.
initColumn
()
}
</
Row
>
</
Modal
>
...
...
src/webui/src/components/Modal/ExperimentDrawer.tsx
View file @
c785655e
...
...
@@ -58,7 +58,7 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
trialMessage
:
trialMessagesArr
};
if
(
this
.
_isCompareMount
===
true
)
{
this
.
setState
(()
=>
({
experiment
:
JSON
.
stringify
(
result
,
null
,
4
)
})
)
;
this
.
setState
({
experiment
:
JSON
.
stringify
(
result
,
null
,
4
)
});
}
}
}));
...
...
src/webui/src/components/Modal/LogDrawer.tsx
View file @
c785655e
...
...
@@ -51,13 +51,13 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
setDispatcher
=
(
value
:
string
)
=>
{
if
(
this
.
_isLogDrawer
===
true
)
{
this
.
setState
(()
=>
({
isLoadispatcher
:
false
,
dispatcherLogStr
:
value
})
)
;
this
.
setState
({
isLoadispatcher
:
false
,
dispatcherLogStr
:
value
});
}
}
setNNImanager
=
(
val
:
string
)
=>
{
if
(
this
.
_isLogDrawer
===
true
)
{
this
.
setState
(()
=>
({
isLoading
:
false
,
nniManagerLogStr
:
val
})
)
;
this
.
setState
({
isLoading
:
false
,
nniManagerLogStr
:
val
});
}
}
...
...
src/webui/src/components/Overview.tsx
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
{
Row
,
Col
}
from
'
antd
'
;
import
{
MANAGER_IP
}
from
'
../static/const
'
;
import
{
Experiment
,
TableObj
,
Parameters
,
TrialNumber
}
from
'
../static/interface
'
;
import
{
getFinal
}
from
'
../static/function
'
;
import
{
EXPERIMENT
,
TRIALS
}
from
'
../static/datamodel
'
;
import
{
Trial
}
from
'
../static/model/trial
'
;
import
SuccessTable
from
'
./overview/SuccessTable
'
;
import
Title1
from
'
./overview/Title1
'
;
import
Progressed
from
'
./overview/Progress
'
;
import
Accuracy
from
'
./overview/Accuracy
'
;
import
SearchSpace
from
'
./overview/SearchSpace
'
;
import
BasicInfo
from
'
./overview/BasicInfo
'
;
import
Trial
Pr
o
from
'
./overview/TrialProfile
'
;
import
Trial
Inf
o
from
'
./overview/TrialProfile
'
;
require
(
'
../static/style/overview.scss
'
);
require
(
'
../static/style/logPath.scss
'
);
...
...
@@ -18,486 +16,70 @@ require('../static/style/accuracy.css');
require
(
'
../static/style/table.scss
'
);
require
(
'
../static/style/overviewTitle.scss
'
);
interface
OverviewState
{
tableData
:
Array
<
TableObj
>
;
experimentAPI
:
object
;
searchSpace
:
object
;
status
:
string
;
errorStr
:
string
;
trialProfile
:
Experiment
;
option
:
object
;
noData
:
string
;
accuracyData
:
object
;
bestAccuracy
:
number
;
accNodata
:
string
;
trialNumber
:
TrialNumber
;
isTop10
:
boolean
;
titleMaxbgcolor
?:
string
;
titleMinbgcolor
?:
string
;
// trial stdout is content(false) or link(true)
isLogCollection
:
boolean
;
isMultiPhase
:
boolean
;
interface
OverviewProps
{
experimentUpdateBroadcast
:
number
;
trialsUpdateBroadcast
:
number
;
}
interface
OverviewProps
{
interval
:
number
;
// user select
whichPageToFresh
:
string
;
concurrency
:
number
;
changeConcurrency
:
(
val
:
number
)
=>
void
;
interface
OverviewState
{
trialConcurrency
:
number
;
metricGraphMode
:
'
max
'
|
'
min
'
;
}
class
Overview
extends
React
.
Component
<
OverviewProps
,
OverviewState
>
{
public
_isMounted
=
false
;
public
intervalID
=
0
;
public
intervalProfile
=
1
;
constructor
(
props
:
OverviewProps
)
{
super
(
props
);
this
.
state
=
{
searchSpace
:
{},
experimentAPI
:
{},
status
:
''
,
errorStr
:
''
,
trialProfile
:
{
id
:
''
,
author
:
''
,
experName
:
''
,
runConcurren
:
1
,
maxDuration
:
0
,
execDuration
:
0
,
MaxTrialNum
:
0
,
startTime
:
0
,
tuner
:
{},
trainingServicePlatform
:
''
},
tableData
:
[],
option
:
{},
noData
:
''
,
// accuracy
accuracyData
:
{},
accNodata
:
''
,
bestAccuracy
:
0
,
trialNumber
:
{
succTrial
:
0
,
failTrial
:
0
,
stopTrial
:
0
,
waitTrial
:
0
,
runTrial
:
0
,
unknowTrial
:
0
,
totalCurrentTrial
:
0
},
isTop10
:
true
,
isLogCollection
:
false
,
isMultiPhase
:
false
trialConcurrency
:
EXPERIMENT
.
trialConcurrency
,
metricGraphMode
:
(
EXPERIMENT
.
optimizeMode
===
'
minimize
'
?
'
min
'
:
'
max
'
),
};
}
// show session
showSessionPro
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
let
sessionData
=
res
.
data
;
let
trialPro
=
[];
const
tempara
=
sessionData
.
params
;
const
trainingPlatform
=
tempara
.
trainingServicePlatform
;
// assessor clusterMeteData
const
clusterMetaData
=
tempara
.
clusterMetaData
;
const
endTimenum
=
sessionData
.
endTime
;
const
assessor
=
tempara
.
assessor
;
const
advisor
=
tempara
.
advisor
;
let
optimizeMode
=
'
other
'
;
if
(
tempara
.
tuner
!==
undefined
)
{
if
(
tempara
.
tuner
.
classArgs
!==
undefined
)
{
if
(
tempara
.
tuner
.
classArgs
.
optimize_mode
!==
undefined
)
{
optimizeMode
=
tempara
.
tuner
.
classArgs
.
optimize_mode
;
}
}
}
// default logCollection is true
const
logCollection
=
tempara
.
logCollection
;
let
expLogCollection
:
boolean
=
false
;
const
isMultiy
:
boolean
=
tempara
.
multiPhase
!==
undefined
?
tempara
.
multiPhase
:
false
;
if
(
optimizeMode
!==
undefined
)
{
if
(
optimizeMode
===
'
minimize
'
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
isTop10
:
false
,
titleMinbgcolor
:
'
#999
'
});
}
}
else
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
isTop10
:
true
,
titleMaxbgcolor
:
'
#999
'
});
}
}
}
if
(
logCollection
!==
undefined
&&
logCollection
!==
'
none
'
)
{
expLogCollection
=
true
;
}
trialPro
.
push
({
id
:
sessionData
.
id
,
author
:
tempara
.
authorName
,
revision
:
sessionData
.
revision
,
experName
:
tempara
.
experimentName
,
runConcurren
:
tempara
.
trialConcurrency
,
logDir
:
sessionData
.
logDir
?
sessionData
.
logDir
:
'
undefined
'
,
maxDuration
:
tempara
.
maxExecDuration
,
execDuration
:
sessionData
.
execDuration
,
MaxTrialNum
:
tempara
.
maxTrialNum
,
startTime
:
sessionData
.
startTime
,
endTime
:
endTimenum
?
endTimenum
:
undefined
,
trainingServicePlatform
:
trainingPlatform
,
tuner
:
tempara
.
tuner
,
assessor
:
assessor
?
assessor
:
undefined
,
advisor
:
advisor
?
advisor
:
undefined
,
clusterMetaData
:
clusterMetaData
?
clusterMetaData
:
undefined
,
logCollection
:
logCollection
});
// search space format loguniform max and min
const
temp
=
tempara
.
searchSpace
;
const
searchSpace
=
temp
!==
undefined
?
JSON
.
parse
(
temp
)
:
{};
Object
.
keys
(
searchSpace
).
map
(
item
=>
{
const
key
=
searchSpace
[
item
].
_type
;
let
value
=
searchSpace
[
item
].
_value
;
switch
(
key
)
{
case
'
quniform
'
:
case
'
qnormal
'
:
case
'
qlognormal
'
:
searchSpace
[
item
].
_value
=
[
value
[
0
],
value
[
1
]];
break
;
default
:
}
});
if
(
this
.
_isMounted
)
{
this
.
setState
({
experimentAPI
:
res
.
data
,
trialProfile
:
trialPro
[
0
],
searchSpace
:
searchSpace
,
isLogCollection
:
expLogCollection
,
isMultiPhase
:
isMultiy
});
}
}
});
this
.
checkStatus
();
}
checkStatus
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/check-status`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
errors
=
res
.
data
.
errors
;
if
(
errors
.
length
!==
0
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
status
:
res
.
data
.
status
,
errorStr
:
res
.
data
.
errors
[
0
]
});
}
}
else
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
status
:
res
.
data
.
status
,
});
}
}
}
});
}
showTrials
=
()
=>
{
this
.
isOffInterval
();
axios
(
`
${
MANAGER_IP
}
/trial-jobs`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
tableData
=
res
.
data
;
const
topTableData
:
Array
<
TableObj
>
=
[];
const
profile
:
TrialNumber
=
{
succTrial
:
0
,
failTrial
:
0
,
stopTrial
:
0
,
waitTrial
:
0
,
runTrial
:
0
,
unknowTrial
:
0
,
totalCurrentTrial
:
0
};
// currently totoal number
profile
.
totalCurrentTrial
=
tableData
.
length
;
Object
.
keys
(
tableData
).
map
(
item
=>
{
switch
(
tableData
[
item
].
status
)
{
case
'
WAITING
'
:
profile
.
waitTrial
+=
1
;
break
;
case
'
UNKNOWN
'
:
profile
.
unknowTrial
+=
1
;
break
;
case
'
FAILED
'
:
profile
.
failTrial
+=
1
;
break
;
case
'
RUNNING
'
:
profile
.
runTrial
+=
1
;
break
;
case
'
USER_CANCELED
'
:
case
'
SYS_CANCELED
'
:
case
'
EARLY_STOPPED
'
:
profile
.
stopTrial
+=
1
;
break
;
case
'
SUCCEEDED
'
:
profile
.
succTrial
+=
1
;
const
desJobDetail
:
Parameters
=
{
parameters
:
{},
intermediate
:
[],
multiProgress
:
1
};
const
duration
=
(
tableData
[
item
].
endTime
-
tableData
[
item
].
startTime
)
/
1000
;
const
acc
=
getFinal
(
tableData
[
item
].
finalMetricData
);
// if hyperparameters is undefine, show error message, else, show parameters value
const
tempara
=
tableData
[
item
].
hyperParameters
;
if
(
tempara
!==
undefined
)
{
const
tempLength
=
tempara
.
length
;
const
parameters
=
JSON
.
parse
(
tempara
[
tempLength
-
1
]).
parameters
;
desJobDetail
.
multiProgress
=
tempara
.
length
;
if
(
typeof
parameters
===
'
string
'
)
{
desJobDetail
.
parameters
=
JSON
.
parse
(
parameters
);
}
else
{
desJobDetail
.
parameters
=
parameters
;
}
}
else
{
desJobDetail
.
parameters
=
{
error
:
'
This trial
\'
s parameters are not available.
'
};
}
if
(
tableData
[
item
].
logPath
!==
undefined
)
{
desJobDetail
.
logPath
=
tableData
[
item
].
logPath
;
}
topTableData
.
push
({
key
:
topTableData
.
length
,
sequenceId
:
tableData
[
item
].
sequenceId
,
id
:
tableData
[
item
].
id
,
duration
:
duration
,
status
:
tableData
[
item
].
status
,
acc
:
acc
,
description
:
desJobDetail
});
break
;
default
:
}
});
// choose top10 or lowest10
const
{
isTop10
}
=
this
.
state
;
if
(
isTop10
===
true
)
{
topTableData
.
sort
((
a
:
TableObj
,
b
:
TableObj
)
=>
{
if
(
a
.
acc
!==
undefined
&&
b
.
acc
!==
undefined
)
{
return
JSON
.
parse
(
b
.
acc
.
default
)
-
JSON
.
parse
(
a
.
acc
.
default
);
}
else
{
return
NaN
;
}
});
}
else
{
topTableData
.
sort
((
a
:
TableObj
,
b
:
TableObj
)
=>
{
if
(
a
.
acc
!==
undefined
&&
b
.
acc
!==
undefined
)
{
return
JSON
.
parse
(
a
.
acc
.
default
)
-
JSON
.
parse
(
b
.
acc
.
default
);
}
else
{
return
NaN
;
}
});
}
topTableData
.
length
=
Math
.
min
(
10
,
topTableData
.
length
);
let
bestDefaultMetric
=
0
;
if
(
topTableData
[
0
]
!==
undefined
)
{
if
(
topTableData
[
0
].
acc
!==
undefined
)
{
bestDefaultMetric
=
JSON
.
parse
(
topTableData
[
0
].
acc
.
default
);
}
}
if
(
this
.
_isMounted
)
{
this
.
setState
({
tableData
:
topTableData
,
trialNumber
:
profile
,
bestAccuracy
:
bestDefaultMetric
});
}
this
.
checkStatus
();
// draw accuracy
this
.
drawPointGraph
();
}
});
}
// trial accuracy graph Default Metric
drawPointGraph
=
()
=>
{
const
{
tableData
}
=
this
.
state
;
const
sourcePoint
=
JSON
.
parse
(
JSON
.
stringify
(
tableData
));
sourcePoint
.
sort
((
a
:
TableObj
,
b
:
TableObj
)
=>
{
if
(
a
.
sequenceId
!==
undefined
&&
b
.
sequenceId
!==
undefined
)
{
return
a
.
sequenceId
-
b
.
sequenceId
;
}
else
{
return
NaN
;
}
});
const
accarr
:
Array
<
number
>
=
[];
const
indexarr
:
Array
<
number
>
=
[];
Object
.
keys
(
sourcePoint
).
map
(
item
=>
{
const
items
=
sourcePoint
[
item
];
if
(
items
.
acc
!==
undefined
)
{
accarr
.
push
(
items
.
acc
.
default
);
indexarr
.
push
(
items
.
sequenceId
);
}
});
const
accOption
=
{
// support max show 0.0000000
grid
:
{
left
:
67
,
right
:
40
},
tooltip
:
{
trigger
:
'
item
'
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
data
:
indexarr
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
,
data
:
accarr
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
accarr
}]
};
if
(
this
.
_isMounted
)
{
this
.
setState
({
accuracyData
:
accOption
},
()
=>
{
if
(
accarr
.
length
===
0
)
{
this
.
setState
({
accNodata
:
'
No data
'
});
}
else
{
this
.
setState
({
accNodata
:
''
});
}
});
}
}
clickMaxTop
=
(
event
:
React
.
SyntheticEvent
<
EventTarget
>
)
=>
{
event
.
stopPropagation
();
// #999 panel active bgcolor; #b3b3b3 as usual
this
.
setState
(()
=>
({
isTop10
:
true
,
titleMaxbgcolor
:
'
#999
'
,
titleMinbgcolor
:
'
#b3b3b3
'
}));
this
.
showTrials
();
this
.
setState
({
metricGraphMode
:
'
max
'
});
}
clickMinTop
=
(
event
:
React
.
SyntheticEvent
<
EventTarget
>
)
=>
{
event
.
stopPropagation
();
this
.
setState
(()
=>
({
isTop10
:
false
,
titleMaxbgcolor
:
'
#b3b3b3
'
,
titleMinbgcolor
:
'
#999
'
}));
this
.
showTrials
();
}
isOffInterval
=
()
=>
{
const
{
status
}
=
this
.
state
;
const
{
interval
}
=
this
.
props
;
if
(
status
===
'
DONE
'
||
status
===
'
ERROR
'
||
status
===
'
STOPPED
'
||
interval
===
0
)
{
window
.
clearInterval
(
this
.
intervalID
);
window
.
clearInterval
(
this
.
intervalProfile
);
return
;
}
this
.
setState
({
metricGraphMode
:
'
min
'
});
}
componentWillReceiveProps
(
nextProps
:
OverviewProps
)
{
const
{
interval
,
whichPageToFresh
}
=
nextProps
;
window
.
clearInterval
(
this
.
intervalID
);
window
.
clearInterval
(
this
.
intervalProfile
);
if
(
whichPageToFresh
.
includes
(
'
/oview
'
))
{
this
.
showTrials
();
this
.
showSessionPro
();
}
if
(
interval
!==
0
)
{
this
.
intervalID
=
window
.
setInterval
(
this
.
showTrials
,
interval
*
1000
);
this
.
intervalProfile
=
window
.
setInterval
(
this
.
showSessionPro
,
interval
*
1000
);
}
changeConcurrency
=
(
val
:
number
)
=>
{
this
.
setState
({
trialConcurrency
:
val
});
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
const
{
interval
}
=
this
.
props
;
this
.
showTrials
();
this
.
showSessionPro
();
if
(
interval
!==
0
)
{
this
.
intervalID
=
window
.
setInterval
(
this
.
showTrials
,
interval
*
1000
);
this
.
intervalProfile
=
window
.
setInterval
(
this
.
showSessionPro
,
interval
*
1000
);
}
}
render
()
{
const
{
trialConcurrency
,
metricGraphMode
}
=
this
.
state
;
const
{
experimentUpdateBroadcast
}
=
this
.
props
;
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
window
.
clearInterval
(
this
.
intervalID
);
window
.
clearInterval
(
this
.
intervalProfile
);
}
const
searchSpace
=
this
.
convertSearchSpace
();
render
()
{
const
bestTrials
=
this
.
findBestTrials
();
const
bestAccuracy
=
bestTrials
.
length
>
0
?
bestTrials
[
0
].
accuracy
!
:
NaN
;
const
accuracyGraphData
=
this
.
generateAccuracyGraph
(
bestTrials
);
const
noDataMessage
=
bestTrials
.
length
>
0
?
''
:
'
No data
'
;
const
{
trialProfile
,
searchSpace
,
tableData
,
accuracyData
,
accNodata
,
status
,
errorStr
,
trialNumber
,
bestAccuracy
,
isMultiPhase
,
titleMaxbgcolor
,
titleMinbgcolor
,
isLogCollection
,
experimentAPI
}
=
this
.
state
;
const
{
concurrency
}
=
this
.
props
;
trialProfile
.
runConcurren
=
concurrency
;
Object
.
keys
(
experimentAPI
).
map
(
item
=>
{
if
(
item
===
'
params
'
)
{
const
temp
=
experimentAPI
[
item
];
Object
.
keys
(
temp
).
map
(
index
=>
{
if
(
index
===
'
trialConcurrency
'
)
{
temp
[
index
]
=
concurrency
;
}
});
}
});
const
titleMaxbgcolor
=
(
metricGraphMode
===
'
max
'
?
'
#999
'
:
'
#b3b3b3
'
);
const
titleMinbgcolor
=
(
metricGraphMode
===
'
min
'
?
'
#999
'
:
'
#b3b3b3
'
);
return
(
<
div
className
=
"overview"
>
{
/* status and experiment block */
}
<
Row
>
<
Title1
text
=
"Experiment"
icon
=
"11.png"
/>
<
BasicInfo
trialProfile
=
{
trialProfile
}
status
=
{
status
}
/>
<
BasicInfo
experimentUpdateBroadcast
=
{
experimentUpdateBroadcast
}
/>
</
Row
>
<
Row
className
=
"overMessage"
>
{
/* status graph */
}
<
Col
span
=
{
9
}
className
=
"prograph overviewBoder cc"
>
<
Title1
text
=
"Status"
icon
=
"5.png"
/>
<
Progressed
trialNumber
=
{
trialNumber
}
trialProfile
=
{
trialProfile
}
bestAccuracy
=
{
bestAccuracy
}
status
=
{
status
}
errors
=
{
errorStr
}
concurrency
=
{
concurrency
}
changeConcurrency
=
{
this
.
props
.
changeConcurrency
}
concurrency
=
{
trialConcurrency
}
changeConcurrency
=
{
this
.
changeConcurrency
}
experimentUpdateBroadcast
=
{
experimentUpdateBroadcast
}
/>
</
Col
>
{
/* experiment parameters search space tuner assessor... */
}
...
...
@@ -512,7 +94,10 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
<
Row
className
=
"experiment"
>
{
/* the scroll bar all the trial profile in the searchSpace div*/
}
<
div
className
=
"experiment searchSpace"
>
<
TrialPro
experiment
=
{
experimentAPI
}
/>
<
TrialInfo
experimentUpdateBroadcast
=
{
experimentUpdateBroadcast
}
concurrency
=
{
trialConcurrency
}
/>
</
div
>
</
Row
>
</
Col
>
...
...
@@ -541,24 +126,79 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
<
Col
span
=
{
8
}
className
=
"overviewBoder"
>
<
Row
className
=
"accuracy"
>
<
Accuracy
accuracyData
=
{
accuracyData
}
accNodata
=
{
accNodata
}
accuracyData
=
{
accuracy
Graph
Data
}
accNodata
=
{
noDataMessage
}
height
=
{
324
}
/>
</
Row
>
</
Col
>
<
Col
span
=
{
16
}
id
=
"succeTable"
>
<
SuccessTable
tableSource
=
{
tableData
}
multiphase
=
{
isMultiPhase
}
logCollection
=
{
isLogCollection
}
trainingPlatform
=
{
trialProfile
.
trainingServicePlatform
}
/>
<
SuccessTable
trialIds
=
{
bestTrials
.
map
(
trial
=>
trial
.
info
.
id
)
}
/>
</
Col
>
</
Row
>
</
Row
>
</
div
>
);
}
private
convertSearchSpace
():
object
{
const
searchSpace
=
Object
.
assign
({},
EXPERIMENT
.
searchSpace
);
Object
.
keys
(
searchSpace
).
map
(
item
=>
{
const
key
=
searchSpace
[
item
].
_type
;
let
value
=
searchSpace
[
item
].
_value
;
switch
(
key
)
{
case
'
quniform
'
:
case
'
qnormal
'
:
case
'
qlognormal
'
:
searchSpace
[
item
].
_value
=
[
value
[
0
],
value
[
1
]];
break
;
default
:
}
});
return
searchSpace
;
}
private
findBestTrials
():
Trial
[]
{
let
bestTrials
=
TRIALS
.
sort
();
if
(
this
.
state
.
metricGraphMode
===
'
max
'
)
{
bestTrials
.
reverse
().
splice
(
10
);
}
else
{
bestTrials
.
splice
(
10
);
}
return
bestTrials
;
}
private
generateAccuracyGraph
(
bestTrials
:
Trial
[]):
object
{
const
xSequence
=
bestTrials
.
map
(
trial
=>
trial
.
sequenceId
);
const
ySequence
=
bestTrials
.
map
(
trial
=>
trial
.
accuracy
);
return
{
// support max show 0.0000000
grid
:
{
left
:
67
,
right
:
40
},
tooltip
:
{
trigger
:
'
item
'
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
data
:
xSequence
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
,
data
:
ySequence
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
ySequence
}]
};
}
}
export
default
Overview
;
src/webui/src/components/SlideBar.tsx
View file @
c785655e
...
...
@@ -26,7 +26,6 @@ interface SliderState {
interface
SliderProps
extends
FormComponentProps
{
changeInterval
:
(
value
:
number
)
=>
void
;
changeFresh
:
(
value
:
string
)
=>
void
;
}
interface
EventPer
{
...
...
@@ -35,7 +34,6 @@ interface EventPer {
class
SlideBar
extends
React
.
Component
<
SliderProps
,
SliderState
>
{
public
_isMounted
=
false
;
public
divMenu
:
HTMLDivElement
|
null
;
public
selectHTML
:
Select
|
null
;
...
...
@@ -57,32 +55,26 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
&&
this
.
_isMounted
)
{
if
(
res
.
status
===
200
)
{
this
.
setState
({
version
:
res
.
data
});
}
});
}
handleMenuClick
=
(
e
:
EventPer
)
=>
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
menuVisible
:
false
});
}
this
.
setState
({
menuVisible
:
false
});
switch
(
e
.
key
)
{
// to see & download experiment parameters
case
'
1
'
:
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isvisibleExperimentDrawer
:
true
}));
}
this
.
setState
({
isvisibleExperimentDrawer
:
true
});
break
;
// to see & download nnimanager log
case
'
2
'
:
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
activeKey
:
'
nnimanager
'
,
isvisibleLogDrawer
:
true
}));
}
this
.
setState
({
activeKey
:
'
nnimanager
'
,
isvisibleLogDrawer
:
true
});
break
;
// to see & download dispatcher log
case
'
3
'
:
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isvisibleLogDrawer
:
true
,
activeKey
:
'
dispatcher
'
}));
}
this
.
setState
({
isvisibleLogDrawer
:
true
,
activeKey
:
'
dispatcher
'
});
break
;
case
'
close
'
:
case
'
10
'
:
...
...
@@ -96,13 +88,10 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
handleVisibleChange
=
(
flag
:
boolean
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
({
menuVisible
:
flag
});
}
this
.
setState
({
menuVisible
:
flag
});
}
getInterval
=
(
value
:
string
)
=>
{
if
(
value
===
'
close
'
)
{
this
.
props
.
changeInterval
(
0
);
}
else
{
...
...
@@ -203,13 +192,9 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
fresh
=
(
event
:
React
.
SyntheticEvent
<
EventTarget
>
)
=>
{
event
.
preventDefault
();
event
.
stopPropagation
();
if
(
this
.
_isMounted
)
{
this
.
setState
({
isdisabledFresh
:
true
},
()
=>
{
const
whichPage
=
window
.
location
.
pathname
;
this
.
props
.
changeFresh
(
whichPage
);
setTimeout
(()
=>
{
this
.
setState
(()
=>
({
isdisabledFresh
:
false
}));
},
1000
);
});
}
this
.
setState
({
isdisabledFresh
:
true
},
()
=>
{
setTimeout
(()
=>
{
this
.
setState
({
isdisabledFresh
:
false
});
},
1000
);
});
}
desktopHTML
=
()
=>
{
...
...
@@ -330,27 +315,18 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
// close log drawer (nnimanager.dispatcher)
closeLogDrawer
=
()
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isvisibleLogDrawer
:
false
,
activeKey
:
''
}));
}
this
.
setState
({
isvisibleLogDrawer
:
false
,
activeKey
:
''
});
}
// close download experiment parameters drawer
closeExpDrawer
=
()
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isvisibleExperimentDrawer
:
false
}));
}
this
.
setState
({
isvisibleExperimentDrawer
:
false
});
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
this
.
getNNIversion
();
}
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
}
render
()
{
const
mobile
=
(<
MediaQuery
maxWidth
=
{
884
}
>
{
this
.
mobileHTML
()
}
</
MediaQuery
>);
const
tablet
=
(<
MediaQuery
minWidth
=
{
885
}
maxWidth
=
{
1241
}
>
{
this
.
tabeltHTML
()
}
</
MediaQuery
>);
...
...
@@ -376,4 +352,4 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
}
export
default
Form
.
create
<
FormComponentProps
>
()(
SlideBar
);
\ No newline at end of file
export
default
Form
.
create
<
FormComponentProps
>
()(
SlideBar
);
src/webui/src/components/TrialsDetail.tsx
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
{
MANAGER_IP
}
from
'
../static/const
'
;
import
{
Row
,
Col
,
Tabs
,
Select
,
Button
,
Icon
}
from
'
antd
'
;
const
Option
=
Select
.
Option
;
import
{
TableObj
,
Parameters
,
ExperimentInfo
}
from
'
../static/
interface
'
;
import
{
getFin
al
}
from
'
../static/
function
'
;
import
{
EXPERIMENT
,
TRIALS
}
from
'
../static/
datamodel
'
;
import
{
Tri
al
}
from
'
../static/
model/trial
'
;
import
DefaultPoint
from
'
./trial-detail/DefaultMetricPoint
'
;
import
Duration
from
'
./trial-detail/Duration
'
;
import
Title1
from
'
./overview/Title1
'
;
...
...
@@ -16,37 +14,22 @@ import '../static/style/trialsDetail.scss';
import
'
../static/style/search.scss
'
;
interface
TrialDetailState
{
accSource
:
object
;
accNodata
:
string
;
tableListSource
:
Array
<
TableObj
>
;
searchResultSource
:
Array
<
TableObj
>
;
isHasSearch
:
boolean
;
experimentLogCollection
:
boolean
;
entriesTable
:
number
;
// table components val
entriesInSelect
:
string
;
searchSpace
:
string
;
isMultiPhase
:
boolean
;
tablePageSize
:
number
;
// table components val
whichGraph
:
string
;
hyperCounts
:
number
;
// user click the hyper-parameter counts
durationCounts
:
number
;
intermediateCounts
:
number
;
experimentInfo
:
ExperimentInfo
;
searchFilter
:
string
;
searchPlaceHolder
:
string
;
searchType
:
string
;
searchFilter
:
(
trial
:
Trial
)
=>
boolean
;
}
interface
TrialsDetailProps
{
interval
:
number
;
whichPageToFresh
:
string
;
columnList
:
Array
<
string
>
;
changeColumn
:
(
val
:
Array
<
string
>
)
=>
void
;
experimentUpdateBroacast
:
number
;
trialsUpdateBroadcast
:
number
;
}
class
TrialsDetail
extends
React
.
Component
<
TrialsDetailProps
,
TrialDetailState
>
{
public
_isMounted
=
false
;
public
interAccuracy
=
0
;
public
interTableList
=
1
;
public
interAllTableList
=
2
;
public
tableList
:
TableList
|
null
;
...
...
@@ -73,335 +56,67 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
constructor
(
props
:
TrialsDetailProps
)
{
super
(
props
);
this
.
state
=
{
accSource
:
{},
accNodata
:
''
,
tableListSource
:
[],
searchResultSource
:
[],
experimentLogCollection
:
false
,
entriesTable
:
20
,
entriesInSelect
:
'
20
'
,
searchSpace
:
''
,
tablePageSize
:
20
,
whichGraph
:
'
1
'
,
isHasSearch
:
false
,
isMultiPhase
:
false
,
hyperCounts
:
0
,
durationCounts
:
0
,
intermediateCounts
:
0
,
experimentInfo
:
{
platform
:
''
,
optimizeMode
:
'
maximize
'
},
searchFilter
:
'
id
'
,
searchPlaceHolder
:
'
Search by id
'
searchType
:
'
id
'
,
searchFilter
:
trial
=>
true
,
};
}
getDetailSource
=
()
=>
{
this
.
isOffIntervals
();
axios
.
all
([
axios
.
get
(
`
${
MANAGER_IP
}
/trial-jobs`
),
axios
.
get
(
`
${
MANAGER_IP
}
/metric-data`
)
])
.
then
(
axios
.
spread
((
res
,
res1
)
=>
{
if
(
res
.
status
===
200
&&
res1
.
status
===
200
)
{
const
trialJobs
=
res
.
data
;
const
metricSource
=
res1
.
data
;
const
trialTable
:
Array
<
TableObj
>
=
[];
Object
.
keys
(
trialJobs
).
map
(
item
=>
{
let
desc
:
Parameters
=
{
parameters
:
{},
intermediate
:
[],
multiProgress
:
1
};
let
duration
=
0
;
const
id
=
trialJobs
[
item
].
id
!==
undefined
?
trialJobs
[
item
].
id
:
''
;
const
status
=
trialJobs
[
item
].
status
!==
undefined
?
trialJobs
[
item
].
status
:
''
;
const
begin
=
trialJobs
[
item
].
startTime
;
const
end
=
trialJobs
[
item
].
endTime
;
if
(
begin
)
{
if
(
end
)
{
duration
=
(
end
-
begin
)
/
1000
;
}
else
{
duration
=
(
new
Date
().
getTime
()
-
begin
)
/
1000
;
}
}
const
tempHyper
=
trialJobs
[
item
].
hyperParameters
;
if
(
tempHyper
!==
undefined
)
{
const
getPara
=
JSON
.
parse
(
tempHyper
[
tempHyper
.
length
-
1
]).
parameters
;
desc
.
multiProgress
=
tempHyper
.
length
;
if
(
typeof
getPara
===
'
string
'
)
{
desc
.
parameters
=
JSON
.
parse
(
getPara
);
}
else
{
desc
.
parameters
=
getPara
;
}
}
else
{
desc
.
parameters
=
{
error
:
'
This trial
\'
s parameters are not available.
'
};
}
if
(
trialJobs
[
item
].
logPath
!==
undefined
)
{
desc
.
logPath
=
trialJobs
[
item
].
logPath
;
}
const
acc
=
getFinal
(
trialJobs
[
item
].
finalMetricData
);
// deal with intermediate result list
const
mediate
:
Array
<
number
>
=
[];
Object
.
keys
(
metricSource
).
map
(
key
=>
{
const
items
=
metricSource
[
key
];
if
(
items
.
trialJobId
===
id
)
{
// succeed trial, last intermediate result is final result
// final result format may be object
if
(
typeof
JSON
.
parse
(
items
.
data
)
===
'
object
'
)
{
mediate
.
push
(
JSON
.
parse
(
items
.
data
).
default
);
}
else
{
mediate
.
push
(
JSON
.
parse
(
items
.
data
));
}
}
});
desc
.
intermediate
=
mediate
;
trialTable
.
push
({
key
:
trialTable
.
length
,
sequenceId
:
trialJobs
[
item
].
sequenceId
,
id
:
id
,
status
:
status
,
duration
:
duration
,
acc
:
acc
,
description
:
desc
,
startTime
:
begin
,
endTime
:
(
end
!==
undefined
)
?
end
:
undefined
});
});
// update search data result
const
{
searchResultSource
,
entriesInSelect
}
=
this
.
state
;
if
(
searchResultSource
.
length
!==
0
)
{
const
temp
:
Array
<
number
>
=
[];
Object
.
keys
(
searchResultSource
).
map
(
index
=>
{
temp
.
push
(
searchResultSource
[
index
].
id
);
});
const
searchResultList
:
Array
<
TableObj
>
=
[];
for
(
let
i
=
0
;
i
<
temp
.
length
;
i
++
)
{
Object
.
keys
(
trialTable
).
map
(
key
=>
{
const
item
=
trialTable
[
key
];
if
(
item
.
id
===
temp
[
i
])
{
searchResultList
.
push
(
item
);
}
});
}
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
searchResultSource
:
searchResultList
}));
}
}
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
tableListSource
:
trialTable
}));
}
if
(
entriesInSelect
===
'
all
'
&&
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
entriesTable
:
trialTable
.
length
}));
}
}
}));
}
// search a trial by trial No. & trial id
searchTrial
=
(
event
:
React
.
ChangeEvent
<
HTMLInputElement
>
)
=>
{
const
targetValue
=
event
.
target
.
value
;
if
(
targetValue
===
''
||
targetValue
===
'
'
)
{
const
{
tableListSource
}
=
this
.
state
;
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
isHasSearch
:
false
,
tableListSource
:
tableListSource
,
}));
}
}
else
{
const
{
tableListSource
,
searchFilter
}
=
this
.
state
;
const
searchResultList
:
Array
<
TableObj
>
=
[];
Object
.
keys
(
tableListSource
).
map
(
key
=>
{
const
item
=
tableListSource
[
key
];
switch
(
searchFilter
)
{
case
'
id
'
:
if
(
item
.
id
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
searchResultList
.
push
(
item
);
}
break
;
case
'
Trial No.
'
:
if
(
item
.
sequenceId
.
toString
()
===
targetValue
)
{
searchResultList
.
push
(
item
);
}
break
;
case
'
status
'
:
if
(
item
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
searchResultList
.
push
(
item
);
}
break
;
case
'
parameters
'
:
const
strParameters
=
JSON
.
stringify
(
item
.
description
.
parameters
,
null
,
4
);
if
(
strParameters
.
includes
(
targetValue
))
{
searchResultList
.
push
(
item
);
}
break
;
default
:
}
});
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
searchResultSource
:
searchResultList
,
isHasSearch
:
true
}));
}
}
}
// close timer
isOffIntervals
=
()
=>
{
const
{
interval
}
=
this
.
props
;
if
(
interval
===
0
)
{
window
.
clearInterval
(
this
.
interTableList
);
let
filter
=
(
trial
:
Trial
)
=>
true
;
if
(
!
targetValue
.
trim
())
{
this
.
setState
({
searchFilter
:
filter
});
return
;
}
else
{
axios
(
`
${
MANAGER_IP
}
/check-status`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
&&
this
.
_isMounted
)
{
const
expStatus
=
res
.
data
.
status
;
if
(
expStatus
===
'
DONE
'
||
expStatus
===
'
ERROR
'
||
expStatus
===
'
STOPPED
'
)
{
window
.
clearInterval
(
this
.
interTableList
);
return
;
}
}
});
}
switch
(
this
.
state
.
searchType
)
{
case
'
id
'
:
filter
=
trial
=>
trial
.
info
.
id
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
());
break
;
case
'
Trial No.
'
:
filter
=
trial
=>
trial
.
info
.
sequenceId
.
toString
()
===
targetValue
;
break
;
case
'
status
'
:
filter
=
trial
=>
trial
.
info
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
());
break
;
case
'
parameters
'
:
// TODO: support filters like `x: 2` (instead of `"x": 2`)
filter
=
trial
=>
JSON
.
stringify
(
trial
.
info
.
hyperParameters
,
null
,
4
).
includes
(
targetValue
);
break
;
default
:
alert
(
`Unexpected search filter
${
this
.
state
.
searchType
}
`
);
}
this
.
setState
({
searchFilter
:
filter
});
}
handleEntriesSelect
=
(
value
:
string
)
=>
{
// user select isn't 'all'
if
(
value
!==
'
all
'
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
entriesTable
:
parseInt
(
value
,
10
)
}));
}
}
else
{
const
{
tableListSource
}
=
this
.
state
;
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
entriesInSelect
:
'
all
'
,
entriesTable
:
tableListSource
.
length
}));
}
}
handleTablePageSizeSelect
=
(
value
:
string
)
=>
{
this
.
setState
({
tablePageSize
:
value
===
'
all
'
?
-
1
:
parseInt
(
value
,
10
)
});
}
handleWhichTabs
=
(
activeKey
:
string
)
=>
{
// const which = JSON.parse(activeKey);
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
whichGraph
:
activeKey
}));
}
this
.
setState
({
whichGraph
:
activeKey
});
}
test
=
()
=>
{
alert
(
'
TableList component was not properly initialized.
'
);
}
get
SearchFilter
=
(
value
:
string
)
=>
{
update
SearchFilter
Type
=
(
value
:
string
)
=>
{
// clear input value and re-render table
if
(
this
.
searchInput
!==
null
)
{
this
.
searchInput
.
value
=
''
;
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isHasSearch
:
false
}));
}
}
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
searchFilter
:
value
,
searchPlaceHolder
:
`Search by
${
value
}
`
}));
}
}
// get and set logCollection val
checkExperimentPlatform
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
method
:
'
GET
'
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
trainingPlatform
:
string
=
res
.
data
.
params
.
trainingServicePlatform
!==
undefined
?
res
.
data
.
params
.
trainingServicePlatform
:
''
;
// default logCollection is true
const
logCollection
=
res
.
data
.
params
.
logCollection
;
let
expLogCollection
:
boolean
=
false
;
const
isMultiy
:
boolean
=
res
.
data
.
params
.
multiPhase
!==
undefined
?
res
.
data
.
params
.
multiPhase
:
false
;
const
tuner
=
res
.
data
.
params
.
tuner
;
// I'll set optimize is maximize if user not set optimize
let
optimize
:
string
=
'
maximize
'
;
if
(
tuner
!==
undefined
)
{
if
(
tuner
.
classArgs
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
===
'
minimize
'
)
{
optimize
=
'
minimize
'
;
}
}
}
}
if
(
logCollection
!==
undefined
&&
logCollection
!==
'
none
'
)
{
expLogCollection
=
true
;
}
if
(
this
.
_isMounted
)
{
this
.
setState
({
experimentInfo
:
{
platform
:
trainingPlatform
,
optimizeMode
:
optimize
},
searchSpace
:
res
.
data
.
params
.
searchSpace
,
experimentLogCollection
:
expLogCollection
,
isMultiPhase
:
isMultiy
});
}
}
});
}
componentWillReceiveProps
(
nextProps
:
TrialsDetailProps
)
{
const
{
interval
,
whichPageToFresh
}
=
nextProps
;
window
.
clearInterval
(
this
.
interTableList
);
if
(
interval
!==
0
)
{
this
.
interTableList
=
window
.
setInterval
(
this
.
getDetailSource
,
interval
*
1000
);
}
if
(
whichPageToFresh
.
includes
(
'
/detail
'
))
{
this
.
getDetailSource
();
}
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
const
{
interval
}
=
this
.
props
;
this
.
getDetailSource
();
this
.
interTableList
=
window
.
setInterval
(
this
.
getDetailSource
,
interval
*
1000
);
this
.
checkExperimentPlatform
();
}
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
window
.
clearInterval
(
this
.
interTableList
);
this
.
setState
({
searchType
:
value
});
}
render
()
{
const
{
tableListSource
,
searchResultSource
,
isHasSearch
,
isMultiPhase
,
entriesTable
,
experimentInfo
,
searchSpace
,
experimentLogCollection
,
whichGraph
,
searchPlaceHolder
}
=
this
.
state
;
const
source
=
isHasSearch
?
searchResultSource
:
tableListSource
;
const
{
tablePageSize
,
whichGraph
}
=
this
.
state
;
const
{
columnList
,
changeColumn
}
=
this
.
props
;
const
source
=
TRIALS
.
filter
(
this
.
state
.
searchFilter
);
const
trialIds
=
TRIALS
.
filter
(
this
.
state
.
searchFilter
).
map
(
trial
=>
trial
.
id
);
return
(
<
div
>
<
div
className
=
"trial"
id
=
"tabsty"
>
...
...
@@ -409,10 +124,9 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
TabPane
tab
=
{
this
.
titleOfacc
}
key
=
"1"
>
<
Row
className
=
"graph"
>
<
DefaultPoint
height
=
{
402
}
showSource
=
{
source
}
whichGraph
=
{
whichGraph
}
optimize
=
{
experimentInfo
.
optimizeMode
}
trialIds
=
{
trialIds
}
visible
=
{
whichGraph
===
'
1
'
}
trialsUpdateBroadcast
=
{
this
.
props
.
trialsUpdateBroadcast
}
/>
</
Row
>
</
TabPane
>
...
...
@@ -420,7 +134,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
Row
className
=
"graph"
>
<
Para
dataSource
=
{
source
}
expSearchSpace
=
{
searchSpace
}
expSearchSpace
=
{
JSON
.
stringify
(
EXPERIMENT
.
searchSpace
)
}
whichGraph
=
{
whichGraph
}
/>
</
Row
>
...
...
@@ -440,7 +154,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
span
>
Show
</
span
>
<
Select
className
=
"entry"
onSelect
=
{
this
.
handle
Entries
Select
}
onSelect
=
{
this
.
handle
TablePageSize
Select
}
defaultValue
=
"20"
>
<
Option
value
=
"20"
>
20
</
Option
>
...
...
@@ -464,7 +178,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Compare
</
Button
>
<
Select
defaultValue
=
"id"
className
=
"filter"
onSelect
=
{
this
.
get
SearchFilter
}
>
<
Select
defaultValue
=
"id"
className
=
"filter"
onSelect
=
{
this
.
update
SearchFilter
Type
}
>
<
Option
value
=
"id"
>
Id
</
Option
>
<
Option
value
=
"Trial No."
>
Trial No.
</
Option
>
<
Option
value
=
"status"
>
Status
</
Option
>
...
...
@@ -473,7 +187,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
input
type
=
"text"
className
=
"search-input"
placeholder
=
{
s
earch
PlaceHolder
}
placeholder
=
{
`S
earch
by
${
this
.
state
.
searchType
}
`
}
onChange
=
{
this
.
searchTrial
}
style
=
{
{
width
:
230
}
}
ref
=
{
text
=>
(
this
.
searchInput
)
=
text
}
...
...
@@ -481,14 +195,11 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</
Col
>
</
Row
>
<
TableList
entries
=
{
entriesTable
}
tableSource
=
{
source
}
isMultiPhase
=
{
isMultiPhase
}
platform
=
{
experimentInfo
.
platform
}
updateList
=
{
this
.
getDetailSource
}
logCollection
=
{
experimentLogCollection
}
pageSize
=
{
tablePageSize
}
tableSource
=
{
source
.
map
(
trial
=>
trial
.
tableRecord
)
}
columnList
=
{
columnList
}
changeColumn
=
{
changeColumn
}
trialsUpdateBroadcast
=
{
this
.
props
.
trialsUpdateBroadcast
}
ref
=
{
(
tabList
)
=>
this
.
tableList
=
tabList
}
/>
</
div
>
...
...
@@ -496,4 +207,4 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}
}
export
default
TrialsDetail
;
\ No newline at end of file
export
default
TrialsDetail
;
src/webui/src/components/overview/BasicInfo.tsx
View file @
c785655e
import
{
Col
,
Row
,
Tooltip
}
from
'
antd
'
;
import
*
as
React
from
'
react
'
;
import
{
Row
,
Col
,
Tooltip
}
from
'
antd
'
;
import
{
Experiment
}
from
'
../../static/interface
'
;
import
{
EXPERIMENT
}
from
'
../../static/datamodel
'
;
import
{
formatTimestamp
}
from
'
../../static/function
'
;
interface
BasicInfoProps
{
trialProfile
:
Experiment
;
status
:
string
;
experimentUpdateBroadcast
:
number
;
}
class
BasicInfo
extends
React
.
Component
<
BasicInfoProps
,
{}
>
{
constructor
(
props
:
BasicInfoProps
)
{
super
(
props
);
}
render
()
{
const
{
trialProfile
}
=
this
.
props
;
return
(
<
Row
className
=
"main"
>
<
Col
span
=
{
8
}
className
=
"padItem basic"
>
<
p
>
Name
</
p
>
<
div
>
{
trialProfile
.
exper
Name
}
</
div
>
<
div
>
{
EXPERIMENT
.
profile
.
params
.
experiment
Name
}
</
div
>
<
p
>
ID
</
p
>
<
div
>
{
trialP
rofile
.
id
}
</
div
>
<
div
>
{
EXPERIMENT
.
p
rofile
.
id
}
</
div
>
</
Col
>
<
Col
span
=
{
8
}
className
=
"padItem basic"
>
<
p
>
Start time
</
p
>
<
div
className
=
"nowrap"
>
{
new
Date
(
trialProfile
.
startTime
).
toLocaleString
(
'
en-US
'
)
}
</
div
>
<
div
className
=
"nowrap"
>
{
formatTimestamp
(
EXPERIMENT
.
profile
.
startTime
)
}
</
div
>
<
p
>
End time
</
p
>
<
div
className
=
"nowrap"
>
{
trialProfile
.
endTime
?
new
Date
(
trialProfile
.
endTime
).
toLocaleString
(
'
en-US
'
)
:
'
none
'
}
</
div
>
<
div
className
=
"nowrap"
>
{
formatTimestamp
(
EXPERIMENT
.
profile
.
endTime
)
}
</
div
>
</
Col
>
<
Col
span
=
{
8
}
className
=
"padItem basic"
>
<
p
>
Log directory
</
p
>
<
div
className
=
"nowrap"
>
<
Tooltip
placement
=
"top"
title
=
{
trialP
rofile
.
logDir
}
>
{
trialP
rofile
.
logDir
}
<
Tooltip
placement
=
"top"
title
=
{
EXPERIMENT
.
p
rofile
.
logDir
||
''
}
>
{
EXPERIMENT
.
p
rofile
.
logDir
||
'
unknown
'
}
</
Tooltip
>
</
div
>
<
p
>
Training platform
</
p
>
<
div
className
=
"nowrap"
>
{
trialProfile
.
trainingServicePlatform
?
trialProfile
.
trainingServicePlatform
:
'
none
'
}
</
div
>
<
div
className
=
"nowrap"
>
{
EXPERIMENT
.
profile
.
params
.
trainingServicePlatform
}
</
div
>
</
Col
>
</
Row
>
);
}
}
export
default
BasicInfo
;
\ No newline at end of file
export
default
BasicInfo
;
src/webui/src/components/overview/NumInput.tsx
0 → 100644
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
{
Button
,
Row
}
from
'
antd
'
;
interface
ConcurrencyInputProps
{
value
:
number
;
updateValue
:
(
val
:
string
)
=>
void
;
}
interface
ConcurrencyInputStates
{
editting
:
boolean
;
}
class
ConcurrencyInput
extends
React
.
Component
<
ConcurrencyInputProps
,
ConcurrencyInputStates
>
{
private
input
=
React
.
createRef
<
HTMLInputElement
>
();
constructor
(
props
:
ConcurrencyInputProps
)
{
super
(
props
);
this
.
state
=
{
editting
:
false
};
}
save
=
()
=>
{
if
(
this
.
input
.
current
!==
null
)
{
this
.
props
.
updateValue
(
this
.
input
.
current
.
value
);
this
.
setState
({
editting
:
false
});
}
}
cancel
=
()
=>
{
this
.
setState
({
editting
:
false
});
}
edit
=
()
=>
{
this
.
setState
({
editting
:
true
});
}
render
()
{
if
(
this
.
state
.
editting
)
{
return
(
<
Row
className
=
"inputBox"
>
<
input
type
=
"number"
className
=
"concurrencyInput"
defaultValue
=
{
this
.
props
.
value
.
toString
()
}
ref
=
{
this
.
input
}
/>
<
Button
type
=
"primary"
className
=
"tableButton editStyle"
onClick
=
{
this
.
save
}
>
Save
</
Button
>
<
Button
type
=
"primary"
onClick
=
{
this
.
cancel
}
style
=
{
{
display
:
'
inline-block
'
,
marginLeft
:
1
}
}
className
=
"tableButton editStyle"
>
Cancel
</
Button
>
</
Row
>
);
}
else
{
return
(
<
Row
className
=
"inputBox"
>
<
input
type
=
"number"
className
=
"concurrencyInput"
disabled
=
{
true
}
value
=
{
this
.
props
.
value
}
/>
<
Button
type
=
"primary"
className
=
"tableButton editStyle"
onClick
=
{
this
.
edit
}
>
Edit
</
Button
>
</
Row
>
);
}
}
}
export
default
ConcurrencyInput
;
src/webui/src/components/overview/Progress.tsx
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
{
Row
,
Col
,
Popover
,
Button
,
message
}
from
'
antd
'
;
import
{
Row
,
Col
,
Popover
,
message
}
from
'
antd
'
;
import
axios
from
'
axios
'
;
import
{
MANAGER_IP
,
CONTROLTYPE
}
from
'
../../static/const
'
;
import
{
E
xperiment
,
TrialNumber
}
from
'
../../static/
interface
'
;
import
{
MANAGER_IP
}
from
'
../../static/const
'
;
import
{
E
XPERIMENT
,
TRIALS
}
from
'
../../static/
datamodel
'
;
import
{
convertTime
}
from
'
../../static/function
'
;
import
ConcurrencyInput
from
'
./NumInput
'
;
import
ProgressBar
from
'
./ProgressItem
'
;
import
LogDrawer
from
'
../Modal/LogDrawer
'
;
import
'
../../static/style/progress.scss
'
;
import
'
../../static/style/probar.scss
'
;
interface
ProgressProps
{
trialProfile
:
Experiment
;
concurrency
:
number
;
trialNumber
:
TrialNumber
;
bestAccuracy
:
number
;
status
:
string
;
errors
:
string
;
changeConcurrency
:
(
val
:
number
)
=>
void
;
experimentUpdateBroadcast
:
number
;
}
interface
ProgressState
{
btnName
:
string
;
isEnable
:
boolean
;
userInputVal
:
string
;
// get user input
cancelSty
:
string
;
isShowLogDrawer
:
boolean
;
}
class
Progressed
extends
React
.
Component
<
ProgressProps
,
ProgressState
>
{
public
conInput
:
HTMLInputElement
|
null
;
public
_isMounted
=
false
;
constructor
(
props
:
ProgressProps
)
{
super
(
props
);
this
.
state
=
{
btnName
:
'
Edit
'
,
isEnable
:
true
,
userInputVal
:
this
.
props
.
trialProfile
.
runConcurren
.
toString
(),
cancelSty
:
'
none
'
,
isShowLogDrawer
:
false
};
}
editTrialConcurrency
=
()
=>
{
const
{
btnName
}
=
this
.
state
;
if
(
this
.
_isMounted
)
{
if
(
btnName
===
'
Edit
'
)
{
// user click edit
this
.
setState
(()
=>
({
isEnable
:
false
,
btnName
:
'
Save
'
,
cancelSty
:
'
inline-block
'
}));
}
else
{
// user click save button
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
method
:
'
GET
'
})
.
then
(
rese
=>
{
if
(
rese
.
status
===
200
)
{
const
{
userInputVal
}
=
this
.
state
;
const
experimentFile
=
rese
.
data
;
const
trialConcurrency
=
experimentFile
.
params
.
trialConcurrency
;
if
(
userInputVal
!==
undefined
)
{
if
(
userInputVal
===
trialConcurrency
.
toString
()
||
userInputVal
===
'
0
'
)
{
message
.
destroy
();
message
.
info
(
`trialConcurrency's value is
${
trialConcurrency
}
, you did not modify it`
,
2
);
}
else
{
experimentFile
.
params
.
trialConcurrency
=
parseInt
(
userInputVal
,
10
);
// rest api, modify trial concurrency value
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
method
:
'
PUT
'
,
headers
:
{
'
Content-Type
'
:
'
application/json;charset=utf-8
'
},
data
:
experimentFile
,
params
:
{
update_type
:
CONTROLTYPE
[
1
]
}
}).
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
message
.
destroy
();
message
.
success
(
`Update
${
CONTROLTYPE
[
1
].
toLocaleLowerCase
()}
successfully`
);
this
.
props
.
changeConcurrency
(
parseInt
(
userInputVal
,
10
));
}
})
.
catch
(
error
=>
{
if
(
error
.
response
.
status
===
500
)
{
if
(
error
.
response
.
data
.
error
)
{
message
.
error
(
error
.
response
.
data
.
error
);
}
else
{
message
.
error
(
`Update
${
CONTROLTYPE
[
1
].
toLocaleLowerCase
()}
failed`
);
}
}
});
// btn -> edit
this
.
setState
(()
=>
({
btnName
:
'
Edit
'
,
isEnable
:
true
,
cancelSty
:
'
none
'
}));
}
}
}
});
}
}
}
cancelFunction
=
()
=>
{
const
{
trialProfile
}
=
this
.
props
;
if
(
this
.
_isMounted
)
{
this
.
setState
(
()
=>
({
btnName
:
'
Edit
'
,
isEnable
:
true
,
cancelSty
:
'
none
'
,
}));
editTrialConcurrency
=
async
(
userInput
:
string
)
=>
{
if
(
!
userInput
.
match
(
/^
[
1-9
]\d
*$/
))
{
message
.
error
(
'
Please enter a positive integer!
'
,
2
);
return
;
}
if
(
this
.
conInput
!==
null
)
{
this
.
conInput
.
value
=
trialProfile
.
runConcurren
.
toString
();
const
newConcurrency
=
parseInt
(
userInput
,
10
);
if
(
newConcurrency
===
this
.
props
.
concurrency
)
{
message
.
info
(
`Trial concurrency has not changed`
,
2
);
return
;
}
}
getUserTrialConcurrency
=
(
event
:
React
.
ChangeEvent
<
HTMLInputElement
>
)
=>
{
const
value
=
event
.
target
.
value
;
if
(
value
.
match
(
/^
[
1-9
]\d
*$/
)
||
value
===
''
)
{
this
.
setState
(()
=>
({
userInputVal
:
value
}));
}
else
{
message
.
error
(
'
Please enter a positive integer!
'
,
2
);
if
(
this
.
conInput
!==
null
)
{
const
{
trialProfile
}
=
this
.
props
;
this
.
conInput
.
value
=
trialProfile
.
runConcurren
.
toString
();
const
newProfile
=
Object
.
assign
({},
EXPERIMENT
.
profile
);
newProfile
.
params
.
trialConcurrency
=
newConcurrency
;
// rest api, modify trial concurrency value
try
{
const
res
=
await
axios
.
put
(
`
${
MANAGER_IP
}
/experiment`
,
newProfile
,
{
params
:
{
update_type
:
'
TRIAL_CONCURRENCY
'
}
});
if
(
res
.
status
===
200
)
{
message
.
success
(
`Successfully updated trial concurrency`
);
// NOTE: should we do this earlier in favor of poor networks?
this
.
props
.
changeConcurrency
(
newConcurrency
);
}
}
catch
(
error
)
{
if
(
error
.
response
&&
error
.
response
.
data
.
error
)
{
message
.
error
(
`Failed to update trial concurrency\n
${
error
.
response
.
data
.
error
}
`
);
}
else
if
(
error
.
response
)
{
message
.
error
(
`Failed to update trial concurrency\nServer responsed
${
error
.
response
.
status
}
`
);
}
else
if
(
error
.
message
)
{
message
.
error
(
`Failed to update trial concurrency\n
${
error
.
message
}
`
);
}
else
{
message
.
error
(
`Failed to update trial concurrency\nUnknown error`
);
}
}
}
isShowDrawer
=
()
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isShowLogDrawer
:
true
}));
}
this
.
setState
({
isShowLogDrawer
:
true
});
}
closeDrawer
=
()
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isShowLogDrawer
:
false
}));
}
this
.
setState
({
isShowLogDrawer
:
false
});
}
componentWillReceiveProps
()
{
const
{
trialProfile
}
=
this
.
props
;
if
(
this
.
conInput
!==
null
)
{
this
.
conInput
.
value
=
trialProfile
.
runConcurren
.
toString
();
}
}
render
()
{
const
{
bestAccuracy
}
=
this
.
props
;
const
{
isShowLogDrawer
}
=
this
.
state
;
co
mponentDidMount
()
{
this
.
_isMounted
=
true
;
}
co
nst
count
=
TRIALS
.
countStatus
();
const
stoppedCount
=
count
.
get
(
'
USER_CANCELED
'
)
!
+
count
.
get
(
'
SYS_CANCELED
'
)
!
+
count
.
get
(
'
EARLY_STOPPED
'
)
!
;
const
bar2
=
count
.
get
(
'
RUNNING
'
)
!
+
count
.
get
(
'
SUCCEEDED
'
)
!
+
count
.
get
(
'
FAILED
'
)
!
+
stoppedCount
;
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
}
const
bar2Percent
=
(
bar2
/
EXPERIMENT
.
profile
.
params
.
maxTrialNum
)
*
100
;
const
percent
=
(
EXPERIMENT
.
profile
.
execDuration
/
EXPERIMENT
.
profile
.
params
.
maxExecDuration
)
*
100
;
const
remaining
=
convertTime
(
EXPERIMENT
.
profile
.
params
.
maxExecDuration
-
EXPERIMENT
.
profile
.
execDuration
);
const
maxDuration
=
convertTime
(
EXPERIMENT
.
profile
.
params
.
maxExecDuration
);
const
maxTrialNum
=
EXPERIMENT
.
profile
.
params
.
maxTrialNum
;
const
execDuration
=
convertTime
(
EXPERIMENT
.
profile
.
execDuration
);
render
()
{
const
{
trialProfile
,
trialNumber
,
bestAccuracy
,
status
,
errors
}
=
this
.
props
;
const
{
isEnable
,
btnName
,
cancelSty
,
isShowLogDrawer
}
=
this
.
state
;
const
bar2
=
trialNumber
.
totalCurrentTrial
-
trialNumber
.
waitTrial
-
trialNumber
.
unknowTrial
;
const
bar2Percent
=
(
bar2
/
trialProfile
.
MaxTrialNum
)
*
100
;
const
percent
=
(
trialProfile
.
execDuration
/
trialProfile
.
maxDuration
)
*
100
;
const
runDuration
=
convertTime
(
trialProfile
.
execDuration
);
const
temp
=
trialProfile
.
maxDuration
-
trialProfile
.
execDuration
;
let
remaining
;
let
errorContent
;
if
(
temp
<
0
)
{
remaining
=
'
0
'
;
}
else
{
remaining
=
convertTime
(
temp
);
}
if
(
errors
!==
''
)
{
if
(
EXPERIMENT
.
error
)
{
errorContent
=
(
<
div
className
=
"errors"
>
{
error
s
}
{
EXPERIMENT
.
error
}
<
div
><
a
href
=
"#"
onClick
=
{
this
.
isShowDrawer
}
>
Learn about
</
a
></
div
>
</
div
>
);
...
...
@@ -196,9 +103,9 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<
Row
className
=
"basic lineBasic"
>
<
p
>
Status
</
p
>
<
div
className
=
"status"
>
<
span
className
=
{
status
}
>
{
status
}
</
span
>
<
span
className
=
{
EXPERIMENT
.
status
}
>
{
EXPERIMENT
.
status
}
</
span
>
{
status
===
'
ERROR
'
EXPERIMENT
.
status
===
'
ERROR
'
?
<
Popover
placement
=
"rightTop"
...
...
@@ -216,26 +123,26 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<
ProgressBar
who
=
"Duration"
percent
=
{
percent
}
description
=
{
run
Duration
}
bgclass
=
{
status
}
maxString
=
{
`Max duration:
${
convertTime
(
trialProfile
.
maxDuration
)
}
`
}
description
=
{
exec
Duration
}
bgclass
=
{
EXPERIMENT
.
status
}
maxString
=
{
`Max duration:
${
maxDuration
}
`
}
/>
<
ProgressBar
who
=
"Trial numbers"
percent
=
{
bar2Percent
}
description
=
{
bar2
.
toString
()
}
bgclass
=
{
status
}
maxString
=
{
`Max trial number:
${
trialProfile
.
M
axTrialNum
}
`
}
bgclass
=
{
EXPERIMENT
.
status
}
maxString
=
{
`Max trial number:
${
m
axTrialNum
}
`
}
/>
<
Row
className
=
"basic colorOfbasic mess"
>
<
p
>
Best metric
</
p
>
<
div
>
{
bestAccuracy
.
toFixed
(
6
)
}
</
div
>
<
div
>
{
isNaN
(
bestAccuracy
)
?
'
N/A
'
:
bestAccuracy
.
toFixed
(
6
)
}
</
div
>
</
Row
>
<
Row
className
=
"mess"
>
<
Col
span
=
{
6
}
>
<
Row
className
=
"basic colorOfbasic"
>
<
p
>
Spent
</
p
>
<
div
>
{
convertTime
(
trialProfile
.
execDuration
)
}
</
div
>
<
div
>
{
execDuration
}
</
div
>
</
Row
>
</
Col
>
<
Col
span
=
{
6
}
>
...
...
@@ -247,54 +154,32 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<
Col
span
=
{
12
}
>
{
/* modify concurrency */
}
<
p
>
Concurrency
</
p
>
<
Row
className
=
"inputBox"
>
<
input
type
=
"number"
disabled
=
{
isEnable
}
onChange
=
{
this
.
getUserTrialConcurrency
}
className
=
"concurrencyInput"
ref
=
{
(
input
)
=>
this
.
conInput
=
input
}
/>
<
Button
type
=
"primary"
className
=
"tableButton editStyle"
onClick
=
{
this
.
editTrialConcurrency
}
>
{
btnName
}
</
Button
>
<
Button
type
=
"primary"
onClick
=
{
this
.
cancelFunction
}
style
=
{
{
display
:
cancelSty
,
marginLeft
:
1
}
}
className
=
"tableButton editStyle"
>
Cancel
</
Button
>
</
Row
>
<
ConcurrencyInput
value
=
{
this
.
props
.
concurrency
}
updateValue
=
{
this
.
editTrialConcurrency
}
/>
</
Col
>
</
Row
>
<
Row
className
=
"mess"
>
<
Col
span
=
{
6
}
>
<
Row
className
=
"basic colorOfbasic"
>
<
p
>
Running
</
p
>
<
div
>
{
trialNumber
.
runTrial
}
</
div
>
<
div
>
{
count
.
get
(
'
RUNNING
'
)
}
</
div
>
</
Row
>
</
Col
>
<
Col
span
=
{
6
}
>
<
Row
className
=
"basic colorOfbasic"
>
<
p
>
Succeeded
</
p
>
<
div
>
{
trialNumber
.
succTrial
}
</
div
>
<
div
>
{
count
.
get
(
'
SUCCEEDED
'
)
}
</
div
>
</
Row
>
</
Col
>
<
Col
span
=
{
6
}
>
<
Row
className
=
"basic"
>
<
p
>
Stopped
</
p
>
<
div
>
{
trialNumber
.
stopTrial
}
</
div
>
<
div
>
{
stoppedCount
}
</
div
>
</
Row
>
</
Col
>
<
Col
span
=
{
6
}
>
<
Row
className
=
"basic"
>
<
p
>
Failed
</
p
>
<
div
>
{
trialNumber
.
failTrial
}
</
div
>
<
div
>
{
count
.
get
(
'
FAILED
'
)
}
</
div
>
</
Row
>
</
Col
>
</
Row
>
...
...
@@ -309,4 +194,4 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
}
}
export
default
Progressed
;
\ No newline at end of file
export
default
Progressed
;
src/webui/src/components/overview/SuccessTable.tsx
View file @
c785655e
...
...
@@ -2,131 +2,83 @@ import * as React from 'react';
import
{
Table
}
from
'
antd
'
;
import
OpenRow
from
'
../public-child/OpenRow
'
;
import
DefaultMetric
from
'
../public-child/DefaultMetrc
'
;
import
{
TableObj
}
from
'
../../static/interface
'
;
import
{
TRIALS
}
from
'
../../static/datamodel
'
;
import
{
TableRecord
}
from
'
../../static/interface
'
;
import
{
convertDuration
}
from
'
../../static/function
'
;
import
'
../../static/style/tableStatus.css
'
;
import
'
../../static/style/openRow.scss
'
;
interface
SuccessTableProps
{
tableSource
:
Array
<
TableObj
>
;
trainingPlatform
:
string
;
logCollection
:
boolean
;
multiphase
:
boolean
;
trialIds
:
string
[];
}
class
SuccessTable
extends
React
.
Component
<
SuccessTableProps
,
{}
>
{
public
_isMounted
=
false
;
function
openRow
(
record
:
TableRecord
)
{
return
(
<
OpenRow
trialId
=
{
record
.
id
}
/>
);
}
class
SuccessTable
extends
React
.
Component
<
SuccessTableProps
,
{}
>
{
constructor
(
props
:
SuccessTableProps
)
{
super
(
props
);
}
openRow
=
(
record
:
TableObj
)
=>
{
const
{
trainingPlatform
,
logCollection
,
multiphase
}
=
this
.
props
;
return
(
<
OpenRow
trainingPlatform
=
{
trainingPlatform
}
record
=
{
record
}
logCollection
=
{
logCollection
}
multiphase
=
{
multiphase
}
/>
);
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
}
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
}
render
()
{
const
{
tableSource
}
=
this
.
props
;
let
bgColor
=
''
;
const
columns
=
[{
title
:
'
Trial No.
'
,
dataIndex
:
'
sequenceId
'
,
key
:
'
sequenceId
'
,
width
:
140
,
className
:
'
tableHead
'
},
{
title
:
'
ID
'
,
dataIndex
:
'
id
'
,
key
:
'
id
'
,
width
:
60
,
className
:
'
tableHead leftTitle
'
,
render
:
(
text
:
string
,
record
:
TableObj
)
=>
{
return
(
<
div
>
{
record
.
id
}
</
div
>
);
},
},
{
title
:
'
Duration
'
,
dataIndex
:
'
duration
'
,
key
:
'
duration
'
,
width
:
140
,
sorter
:
(
a
:
TableObj
,
b
:
TableObj
)
=>
(
a
.
duration
as
number
)
-
(
b
.
duration
as
number
),
render
:
(
text
:
string
,
record
:
TableObj
)
=>
{
let
duration
;
if
(
record
.
duration
!==
undefined
)
{
// duration is nagative number(-1) & 0-1
if
(
record
.
duration
>
0
&&
record
.
duration
<
1
||
record
.
duration
<
0
)
{
duration
=
`
${
record
.
duration
}
s`
;
}
else
{
duration
=
convertDuration
(
record
.
duration
);
}
}
else
{
duration
=
0
;
const
columns
=
[
{
title
:
'
Trial No.
'
,
dataIndex
:
'
sequenceId
'
,
width
:
140
,
className
:
'
tableHead
'
},
{
title
:
'
ID
'
,
dataIndex
:
'
id
'
,
width
:
60
,
className
:
'
tableHead leftTitle
'
,
render
:
(
text
:
string
,
record
:
TableRecord
)
=>
{
return
(
<
div
>
{
record
.
id
}
</
div
>
);
},
},
{
title
:
'
Duration
'
,
dataIndex
:
'
duration
'
,
width
:
140
,
render
:
(
text
:
string
,
record
:
TableRecord
)
=>
{
return
(
<
div
className
=
"durationsty"
><
div
>
{
convertDuration
(
record
.
duration
)
}
</
div
></
div
>
);
},
},
{
title
:
'
Status
'
,
dataIndex
:
'
status
'
,
width
:
150
,
className
:
'
tableStatus
'
,
render
:
(
text
:
string
,
record
:
TableRecord
)
=>
{
return
(
<
div
className
=
{
`
${
record
.
status
}
commonStyle`
}
>
{
record
.
status
}
</
div
>
);
}
return
(
<
div
className
=
"durationsty"
><
div
>
{
duration
}
</
div
></
div
>
);
},
},
{
title
:
'
Status
'
,
dataIndex
:
'
status
'
,
key
:
'
status
'
,
width
:
150
,
className
:
'
tableStatus
'
,
render
:
(
text
:
string
,
record
:
TableObj
)
=>
{
bgColor
=
record
.
status
;
return
(
<
div
className
=
{
`
${
bgColor
}
commonStyle`
}
>
{
record
.
status
}
</
div
>
);
}
},
{
title
:
'
Default metric
'
,
dataIndex
:
'
acc
'
,
key
:
'
acc
'
,
sorter
:
(
a
:
TableObj
,
b
:
TableObj
)
=>
{
if
(
a
.
acc
!==
undefined
&&
b
.
acc
!==
undefined
)
{
return
JSON
.
parse
(
a
.
acc
.
default
)
-
JSON
.
parse
(
b
.
acc
.
default
);
}
else
{
return
NaN
;
},
{
title
:
'
Default metric
'
,
dataIndex
:
'
accuracy
'
,
render
:
(
text
:
string
,
record
:
TableRecord
)
=>
{
return
(
<
DefaultMetric
trialId
=
{
record
.
id
}
/>
);
}
},
render
:
(
text
:
string
,
record
:
TableObj
)
=>
{
return
(
<
DefaultMetric
record
=
{
record
}
/>
);
}
}
];
];
return
(
<
div
className
=
"tabScroll"
>
<
Table
columns
=
{
columns
}
expandedRowRender
=
{
this
.
openRow
}
dataSource
=
{
tableSource
}
expandedRowRender
=
{
openRow
}
dataSource
=
{
TRIALS
.
table
(
this
.
props
.
trialIds
)
}
className
=
"commonTableStyle"
pagination
=
{
false
}
/>
</
div
>
</
div
>
);
}
}
...
...
src/webui/src/components/overview/TrialProfile.tsx
View file @
c785655e
import
*
as
React
from
'
react
'
;
import
MonacoEditor
from
'
react-monaco-editor
'
;
import
{
MONACO
}
from
'
../../static/const
'
;
import
{
EXPERIMENT
}
from
'
../../static/datamodel
'
;
interface
TrialInfoProps
{
experiment
:
object
;
experimentUpdateBroadcast
:
number
;
concurrency
:
number
;
}
class
TrialInfo
extends
React
.
Component
<
TrialInfoProps
,
{}
>
{
...
...
@@ -12,32 +14,21 @@ class TrialInfo extends React.Component<TrialInfoProps, {}> {
super
(
props
);
}
componentWillReceiveProps
(
nextProps
:
TrialInfoProps
)
{
const
experiments
=
nextProps
.
experiment
;
Object
.
keys
(
experiments
).
map
(
key
=>
{
switch
(
key
)
{
case
'
id
'
:
case
'
logDir
'
:
case
'
startTime
'
:
case
'
endTime
'
:
experiments
[
key
]
=
undefined
;
break
;
case
'
params
'
:
const
params
=
experiments
[
key
];
Object
.
keys
(
params
).
map
(
item
=>
{
if
(
item
===
'
experimentName
'
||
item
===
'
searchSpace
'
||
item
===
'
trainingServicePlatform
'
)
{
params
[
item
]
=
undefined
;
}
});
break
;
default
:
render
()
{
const
blacklist
=
[
'
id
'
,
'
logDir
'
,
'
startTime
'
,
'
endTime
'
,
'
experimentName
'
,
'
searchSpace
'
,
'
trainingServicePlatform
'
];
// tslint:disable-next-line:no-any
const
filter
=
(
key
:
string
,
val
:
any
)
=>
{
if
(
key
===
'
trialConcurrency
'
)
{
return
this
.
props
.
concurrency
;
}
});
}
return
blacklist
.
includes
(
key
)
?
undefined
:
val
;
};
const
profile
=
JSON
.
stringify
(
EXPERIMENT
.
profile
,
filter
,
2
);
render
()
{
const
{
experiment
}
=
this
.
props
;
// FIXME: highlight not working?
return
(
<
div
className
=
"profile"
>
<
MonacoEditor
...
...
@@ -45,7 +36,7 @@ class TrialInfo extends React.Component<TrialInfoProps, {}> {
height
=
"361"
language
=
"json"
theme
=
"vs-light"
value
=
{
JSON
.
stringify
(
experiment
,
null
,
2
)
}
value
=
{
profile
}
options
=
{
MONACO
}
/>
</
div
>
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment