Unverified Commit c785655e authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
ppo_tuner.py including:
class PPOTuner
"""
import os
import copy
import logging
import numpy as np
import json_tricks
from gym import spaces
import nni
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from .model import Model
from .util import set_global_seeds
from .policy import build_lstm_policy
logger = logging.getLogger('ppo_tuner_AutoML')
def constfn(val):
    """Wrap a constant value as a schedule function.

    Returns a callable that ignores its argument and always yields ``val``,
    so constants share the interface of a decaying lr/cliprange schedule.
    """
    def _const(_unused):
        return val
    return _const
class ModelConfig:
    """
    Configurations of the PPO model.

    Spaces and rollout sizes start empty/zero and are filled in once the
    search space is received; the remaining fields hold PPO hyper-parameter
    defaults.
    """
    def __init__(self):
        # spaces and rollout sizing, set after the search space is known
        self.observation_space = None
        self.action_space = None
        self.num_envs = 0
        self.nsteps = 0
        # optimization objective coefficients
        self.ent_coef = 0.0
        self.lr = 3e-4
        self.vf_coef = 0.5
        self.max_grad_norm = 0.5
        # advantage estimation
        self.gamma = 0.99
        self.lam = 0.95
        self.cliprange = 0.2
        # architecture / training schedule
        self.embedding_size = None   # the embedding is for each action
        self.noptepochs = 4          # number of training epochs per update
        self.total_timesteps = 5000  # number of timesteps (i.e. number of actions taken in the environment)
        self.nminibatches = 4        # number of training minibatches per update. For recurrent policies,
                                     # should be smaller or equal than number of environments run in parallel.
class TrialsInfo:
    """
    Book-keeping for one batch of trials produced by a single model inference.
    """
    def __init__(self, obs, actions, values, neglogpacs, dones, last_value, inf_batch_size):
        self.iter = 0
        self.obs = obs
        self.actions = actions
        self.values = values
        self.neglogpacs = neglogpacs
        self.dones = dones
        self.last_value = last_value
        self.rewards = None
        self.returns = None
        self.inf_batch_size = inf_batch_size

    def get_next(self):
        """Return (index, actions) of the next trial, or (None, None) when exhausted."""
        if self.iter >= self.inf_batch_size:
            return None, None
        cur = self.iter
        actions = [step[cur] for step in self.actions]
        self.iter = cur + 1
        return cur, actions

    def update_rewards(self, rewards, returns):
        """Record reward and return arrays once the trials have finished."""
        self.rewards = rewards
        self.returns = returns

    def convert_shape(self):
        """Flatten each (nsteps, nenvs, ...) array to (nsteps * nenvs, ...)."""
        def _flatten01(arr):
            # swap axes 0/1 then merge them into a single leading axis
            shape = arr.shape
            return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
        for attr in ('obs', 'returns', 'dones', 'actions', 'values', 'neglogpacs'):
            setattr(self, attr, _flatten01(getattr(self, attr)))
class PPOModel:
    """
    PPO model: builds the LSTM policy/value network and drives inference
    (rollout generation) and training on finished trials.
    """
    def __init__(self, model_config, mask):
        """
        Parameters:
        ----------
        model_config: ModelConfig with spaces and hyper-parameters filled in
        mask: numpy mask of the available actions for each step within one trial
        """
        self.model_config = model_config
        self.states = None      # initial state of lstm in policy/value network
        self.nupdates = None    # the number of times train() is invoked, used to tune lr and cliprange
        self.cur_update = 1     # record the current update
        self.np_mask = mask     # record the mask of each action within one trial

        set_global_seeds(None)
        assert isinstance(self.model_config.lr, float)
        self.lr = constfn(self.model_config.lr)
        assert isinstance(self.model_config.cliprange, float)
        self.cliprange = constfn(self.model_config.cliprange)

        # build lstm policy network, value function shares the same network
        policy = build_lstm_policy(model_config)

        # Get the nb of env
        nenvs = model_config.num_envs

        # Calculate the batch_size
        self.nbatch = nbatch = nenvs * model_config.nsteps  # num of record per update
        nbatch_train = nbatch // model_config.nminibatches  # get batch size
        # self.nupdates is used to tune lr and cliprange
        self.nupdates = self.model_config.total_timesteps // self.nbatch

        # Instantiate the model object (that creates act_model and train_model)
        self.model = Model(policy=policy, nbatch_act=nenvs, nbatch_train=nbatch_train,
                           nsteps=model_config.nsteps, ent_coef=model_config.ent_coef,
                           vf_coef=model_config.vf_coef,
                           max_grad_norm=model_config.max_grad_norm, np_mask=self.np_mask)

        self.states = self.model.initial_state
        logger.info('=== finished PPOModel initialization')

    def inference(self, num):
        """
        generate actions along with related info from policy network.
        observation is the action of the last step.

        Parameters:
        ----------
        num: the number of trials to generate
        """
        # Here, we init the lists that will contain the mb of experiences
        mb_obs, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], []
        # initial observation
        # use the (n+1)th embedding to represent the first step action
        first_step_ob = self.model_config.action_space.n
        obs = [first_step_ob for _ in range(num)]
        dones = [True for _ in range(num)]
        states = self.states
        # For n in range number of steps
        for cur_step in range(self.model_config.nsteps):
            # Given observations, get action value and neglogpacs
            actions, values, states, neglogpacs = self.model.step(cur_step, obs, S=states, M=dones)
            mb_obs.append(obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(dones)
            # the action of this step becomes the observation of the next step
            obs[:] = actions
            # a trial spans exactly nsteps actions, so done only at the last step
            if cur_step == self.model_config.nsteps - 1:
                dones = [True for _ in range(num)]
            else:
                dones = [False for _ in range(num)]

        # batch of steps to batch of rollouts
        np_obs = np.asarray(obs)
        mb_obs = np.asarray(mb_obs, dtype=np_obs.dtype)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        # BUGFIX: np.bool was deprecated (NumPy 1.20) and removed (NumPy 1.24);
        # the builtin bool is the documented replacement.
        mb_dones = np.asarray(mb_dones, dtype=bool)
        last_values = self.model.value(np_obs, S=states, M=dones)
        return mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values

    def compute_rewards(self, trials_info, trials_result):
        """
        compute the rewards of the trials in trials_info based on trials_result,
        and update the rewards in trials_info

        Parameters:
        ----------
        trials_info: info of the generated trials
        trials_result: final results (e.g., acc) of the generated trials
        """
        # every step of a trial receives the trial's final result as reward
        mb_rewards = np.asarray([trials_result for _ in trials_info.actions], dtype=np.float32)
        # discount/bootstrap off value fn
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        # BUGFIX: bool instead of removed np.bool (see inference)
        last_dones = np.asarray([True for _ in trials_result], dtype=bool)  # ugly
        # GAE: walk the steps backwards accumulating the advantage estimate
        for t in reversed(range(self.model_config.nsteps)):
            if t == self.model_config.nsteps - 1:
                nextnonterminal = 1.0 - last_dones
                nextvalues = trials_info.last_value
            else:
                nextnonterminal = 1.0 - trials_info.dones[t+1]
                nextvalues = trials_info.values[t+1]
            delta = mb_rewards[t] + self.model_config.gamma * nextvalues * nextnonterminal - trials_info.values[t]
            mb_advs[t] = lastgaelam = delta + self.model_config.gamma * self.model_config.lam * nextnonterminal * lastgaelam
        mb_returns = mb_advs + trials_info.values

        trials_info.update_rewards(mb_rewards, mb_returns)
        trials_info.convert_shape()

    def train(self, trials_info, nenvs):
        """
        train the policy/value network using trials_info

        Parameters:
        ----------
        trials_info: complete info of the generated trials from the previous inference
        nenvs: the batch size of the (previous) inference
        """
        # keep frac decay for future optimization
        if self.cur_update <= self.nupdates:
            frac = 1.0 - (self.cur_update - 1.0) / self.nupdates
        else:
            logger.warning('current update (self.cur_update) %d has exceeded total updates (self.nupdates) %d',
                           self.cur_update, self.nupdates)
            frac = 1.0 - (self.nupdates - 1.0) / self.nupdates
        lrnow = self.lr(frac)
        cliprangenow = self.cliprange(frac)
        self.cur_update += 1

        states = self.states
        assert states is not None  # recurrent version
        assert nenvs % self.model_config.nminibatches == 0
        envsperbatch = nenvs // self.model_config.nminibatches
        envinds = np.arange(nenvs)
        flatinds = np.arange(nenvs * self.model_config.nsteps).reshape(nenvs, self.model_config.nsteps)
        for _ in range(self.model_config.noptepochs):
            # shuffle env indices so minibatches differ across epochs
            np.random.shuffle(envinds)
            for start in range(0, nenvs, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                # gather the flattened step indices belonging to these envs
                mbflatinds = flatinds[mbenvinds].ravel()
                slices = (arr[mbflatinds] for arr in (trials_info.obs, trials_info.returns, trials_info.dones,
                                                      trials_info.actions, trials_info.values, trials_info.neglogpacs))
                mbstates = states[mbenvinds]
                self.model.train(lrnow, cliprangenow, *slices, mbstates)
class PPOTuner(Tuner):
    """
    PPOTuner: explores a NAS search space (mutable layers) with PPO
    (Proximal Policy Optimization) over an LSTM policy network.
    """

    def __init__(self, optimize_mode, trials_per_update=20, epochs_per_update=4, minibatch_size=4,
                 ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, cliprange=0.2):
        """
        initialization, PPO model is not initialized here as search space is not received yet.

        Parameters:
        ----------
        optimize_mode: maximize or minimize
        trials_per_update: number of trials to have for each model update
        epochs_per_update: number of epochs to run for each model update
        minibatch_size: minibatch size (number of trials) for the update
        ent_coef: policy entropy coefficient in the optimization objective
        lr: learning rate of the model (lstm network), constant
        vf_coef: value function loss coefficient in the optimization objective
        max_grad_norm: gradient norm clipping coefficient
        gamma: discounting factor
        lam: advantage estimation discounting factor (lambda in the paper)
        cliprange: cliprange in the PPO algorithm, constant
        """
        self.optimize_mode = OptimizeMode(optimize_mode)
        self.model_config = ModelConfig()
        self.model = None
        self.search_space = None
        self.running_trials = {}                 # key: parameter_id, value: actions/states/etc.
        self.inf_batch_size = trials_per_update  # number of trials to generate in one inference
        self.first_inf = True                    # indicate whether it is the first time to inference new trials
        self.trials_result = [None for _ in range(self.inf_batch_size)]  # results of finished trials
        self.credit = 0                          # record the unsatisfied trial requests
        self.param_ids = []
        self.finished_trials = 0
        self.chosen_arch_template = {}

        self.actions_spaces = None
        self.actions_to_config = None
        self.full_act_space = None
        self.trials_info = None

        self.all_trials = {}  # used to dedup the same trial, key: config, value: final result

        self.model_config.num_envs = self.inf_batch_size
        self.model_config.noptepochs = epochs_per_update
        self.model_config.nminibatches = minibatch_size
        # BUGFIX: these hyper-parameters were accepted but never used before;
        # forward them into the model config so user-provided values take effect
        # (defaults are identical to ModelConfig's, so default behavior is unchanged).
        self.model_config.ent_coef = ent_coef
        self.model_config.lr = lr
        self.model_config.vf_coef = vf_coef
        self.model_config.max_grad_norm = max_grad_norm
        self.model_config.gamma = gamma
        self.model_config.lam = lam
        self.model_config.cliprange = cliprange

        self.send_trial_callback = None
        logger.info('=== finished PPOTuner initialization')

    def _process_one_nas_space(self, block_name, block_space):
        """
        process nas space to determine observation space and action space

        Parameters:
        ----------
        block_name: the name of the mutable block
        block_space: search space of this mutable block

        Returns:
        ----------
        actions_spaces: list of the space of each action
        actions_to_config: the mapping from action to generated configuration
        """
        actions_spaces = []
        actions_to_config = []

        block_arch_temp = {}
        for l_name, layer in block_space.items():
            chosen_layer_temp = {}

            if len(layer['layer_choice']) > 1:
                actions_spaces.append(layer['layer_choice'])
                actions_to_config.append((block_name, l_name, 'chosen_layer'))
                chosen_layer_temp['chosen_layer'] = None
            else:
                assert len(layer['layer_choice']) == 1
                chosen_layer_temp['chosen_layer'] = layer['layer_choice'][0]

            if layer['optional_input_size'] not in [0, 1, [0, 1]]:
                # BUGFIX: message typo corrected ("pecified" -> "specified")
                raise ValueError('Optional_input_size can only be 0, 1, or [0, 1], but the specified one is %s'
                                 % (layer['optional_input_size']))
            if isinstance(layer['optional_input_size'], list):
                # zero or one input: prepend an explicit "None" choice
                actions_spaces.append(["None", *layer['optional_inputs']])
                actions_to_config.append((block_name, l_name, 'chosen_inputs'))
                chosen_layer_temp['chosen_inputs'] = None
            elif layer['optional_input_size'] == 1:
                actions_spaces.append(layer['optional_inputs'])
                actions_to_config.append((block_name, l_name, 'chosen_inputs'))
                chosen_layer_temp['chosen_inputs'] = None
            elif layer['optional_input_size'] == 0:
                chosen_layer_temp['chosen_inputs'] = []
            else:
                raise ValueError('invalid type and value of optional_input_size')

            block_arch_temp[l_name] = chosen_layer_temp

        self.chosen_arch_template[block_name] = block_arch_temp

        return actions_spaces, actions_to_config

    def _process_nas_space(self, search_space):
        """
        process nas search space to get action/observation space
        """
        actions_spaces = []
        actions_to_config = []
        for b_name, block in search_space.items():
            if block['_type'] != 'mutable_layer':
                raise ValueError('PPOTuner only accept mutable_layer type in search space, but the current one is %s'%(block['_type']))
            block = block['_value']
            act, act_map = self._process_one_nas_space(b_name, block)
            actions_spaces.extend(act)
            actions_to_config.extend(act_map)

        # calculate observation space
        dedup = {}
        for step in actions_spaces:
            for action in step:
                dedup[action] = 1
        full_act_space = [act for act, _ in dedup.items()]
        assert len(full_act_space) == len(dedup)
        observation_space = len(full_act_space)
        nsteps = len(actions_spaces)

        return actions_spaces, actions_to_config, full_act_space, observation_space, nsteps

    def _generate_action_mask(self):
        """
        different step could have different action space. to deal with this case, we merge all the
        possible actions into one action space, and use mask to indicate available actions for each step
        """
        two_masks = []

        # 0/1 mask: 1 marks an action available at this step
        mask = []
        for acts in self.actions_spaces:
            one_mask = [0 for _ in range(len(self.full_act_space))]
            for act in acts:
                idx = self.full_act_space.index(act)
                one_mask[idx] = 1
            mask.append(one_mask)
        two_masks.append(mask)

        # additive mask: -inf suppresses unavailable actions' logits
        mask = []
        for acts in self.actions_spaces:
            one_mask = [-np.inf for _ in range(len(self.full_act_space))]
            for act in acts:
                idx = self.full_act_space.index(act)
                one_mask[idx] = 0
            mask.append(one_mask)
        two_masks.append(mask)

        return np.asarray(two_masks, dtype=np.float32)

    def update_search_space(self, search_space):
        """
        get search space, currently the space only includes that for NAS

        Parameters:
        ----------
        search_space: search space for NAS

        Returns:
        -------
        no return
        """
        logger.info('=== update search space %s', search_space)
        # the search space is expected to arrive exactly once
        assert self.search_space is None
        self.search_space = search_space

        assert self.model_config.observation_space is None
        assert self.model_config.action_space is None

        self.actions_spaces, self.actions_to_config, self.full_act_space, obs_space, nsteps = self._process_nas_space(search_space)

        self.model_config.observation_space = spaces.Discrete(obs_space)
        self.model_config.action_space = spaces.Discrete(obs_space)
        self.model_config.nsteps = nsteps

        # generate mask in numpy
        mask = self._generate_action_mask()

        assert self.model is None
        self.model = PPOModel(self.model_config, mask)

    def _actions_to_config(self, actions):
        """
        given actions, to generate the corresponding trial configuration
        """
        chosen_arch = copy.deepcopy(self.chosen_arch_template)
        for cnt, act in enumerate(actions):
            act_name = self.full_act_space[act]
            (block_name, layer_name, key) = self.actions_to_config[cnt]
            if key == 'chosen_inputs':
                if act_name == 'None':
                    chosen_arch[block_name][layer_name][key] = []
                else:
                    chosen_arch[block_name][layer_name][key] = [act_name]
            elif key == 'chosen_layer':
                chosen_arch[block_name][layer_name][key] = act_name
            else:
                raise ValueError('unrecognized key: {0}'.format(key))
        return chosen_arch

    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
        """
        Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
        """
        result = []
        self.send_trial_callback = kwargs['st_callback']
        for parameter_id in parameter_id_list:
            had_exception = False
            try:
                logger.debug("generating param for %s", parameter_id)
                res = self.generate_parameters(parameter_id, **kwargs)
            except nni.NoMoreTrialError:
                had_exception = True
            if not had_exception:
                result.append(res)
        return result

    def generate_parameters(self, parameter_id, **kwargs):
        """
        generate parameters, if no trial configration for now, self.credit plus 1 to send the config later
        """
        if self.first_inf:
            self.trials_result = [None for _ in range(self.inf_batch_size)]
            mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size)
            self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs,
                                          mb_dones, last_values, self.inf_batch_size)
            self.first_inf = False

        trial_info_idx, actions = self.trials_info.get_next()
        if trial_info_idx is None:
            # current batch exhausted; remember the request and satisfy it later
            self.credit += 1
            self.param_ids.append(parameter_id)
            raise nni.NoMoreTrialError('no more parameters now.')

        self.running_trials[parameter_id] = trial_info_idx
        new_config = self._actions_to_config(actions)
        return new_config

    def _next_round_inference(self):
        """
        train the PPO model with the finished batch, then run inference to
        generate the next batch of trials and satisfy pending requests.
        """
        self.finished_trials = 0
        self.model.compute_rewards(self.trials_info, self.trials_result)
        self.model.train(self.trials_info, self.inf_batch_size)
        self.running_trials = {}
        # generate new trials
        self.trials_result = [None for _ in range(self.inf_batch_size)]
        mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size)
        self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs,
                                      mb_dones, last_values, self.inf_batch_size)
        # check credit and submit new trials
        for _ in range(self.credit):
            trial_info_idx, actions = self.trials_info.get_next()
            if trial_info_idx is None:
                logger.warning('No enough trial config, trials_per_update is suggested to be larger than trialConcurrency')
                break
            assert self.param_ids
            param_id = self.param_ids.pop()
            self.running_trials[param_id] = trial_info_idx
            new_config = self._actions_to_config(actions)
            self.send_trial_callback(param_id, new_config)
            self.credit -= 1

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """
        receive trial's result. if the number of finished trials equals self.inf_batch_size, start the next update to
        train the model
        """
        trial_info_idx = self.running_trials.pop(parameter_id, None)
        assert trial_info_idx is not None

        value = extract_scalar_reward(value)
        if self.optimize_mode == OptimizeMode.Minimize:
            value = -value

        self.trials_result[trial_info_idx] = value
        self.finished_trials += 1

        if self.finished_trials == self.inf_batch_size:
            self._next_round_inference()

    def trial_end(self, parameter_id, success, **kwargs):
        """
        to deal with trial failure
        """
        if not success:
            if parameter_id not in self.running_trials:
                logger.warning('The trial is failed, but self.running_trial does not have this trial')
                return
            trial_info_idx = self.running_trials.pop(parameter_id, None)
            assert trial_info_idx is not None
            # use mean of finished trials as the result of this failed trial
            values = [val for val in self.trials_result if val is not None]
            # BUGFIX: removed leftover debug tag ('zql') and switched to lazy %-formatting
            logger.warning('values of finished trials: %s', values)
            self.trials_result[trial_info_idx] = (sum(values) / len(values)) if values else 0
            self.finished_trials += 1
            if self.finished_trials == self.inf_batch_size:
                self._next_round_inference()

    def import_data(self, data):
        """
        Import additional data for tuning

        Parameters
        ----------
        data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
        """
        logger.warning('PPOTuner cannot leverage imported data.')
enum34
gym
tensorflow
\ No newline at end of file
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
util functions
"""
import os
import random
import multiprocessing
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete
def set_global_seeds(i):
    """Seed tensorflow, numpy and the stdlib RNG (pass None to un-seed)."""
    rank = 0
    if i is None:
        myseed = None
    else:
        myseed = i + 1000 * rank
    tf.set_random_seed(myseed)
    np.random.seed(myseed)
    random.seed(myseed)
def batch_to_seq(h, nbatch, nsteps, flat=False):
    """Reshape a batch tensor into a list of ``nsteps`` per-step tensors."""
    new_shape = [nbatch, nsteps] if flat else [nbatch, nsteps, -1]
    h = tf.reshape(h, new_shape)
    steps = tf.split(axis=1, num_or_size_splits=nsteps, value=h)
    return [tf.squeeze(step, [1]) for step in steps]
def seq_to_batch(h, flat=False):
    """Inverse of batch_to_seq: merge a list of per-step tensors back into one batch."""
    if flat:
        return tf.reshape(tf.stack(values=h, axis=1), [-1])
    shape = h[0].get_shape().as_list()
    assert len(shape) > 1
    nh = h[0].get_shape()[-1].value
    return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
def lstm(xs, ms, s, scope, nh, init_scale=1.0):
    """
    LSTM cell, manually unrolled over the sequence.

    Parameters:
    ----------
    xs: list of per-step input tensors (each [nbatch, nin]); overwritten in
        place with the per-step hidden states
    ms: list of per-step mask tensors; 1.0 resets the recurrent state
    s: concatenated cell and hidden state tensor [nbatch, 2*nh]
    scope: variable scope name for the cell weights
    nh: hidden state size
    init_scale: scale for the orthogonal weight initializer

    Returns:
    -------
    (xs, s): the hidden state per step, and the final concatenated state
    """
    nbatch, nin = [v.value for v in xs[0].get_shape()]
    with tf.variable_scope(scope):
        # wx/wh pack input/hidden projections for all four gates (split below)
        wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        # zero out the recurrent state wherever the mask is 1.0
        c = c*(1-m)
        h = h*(1-m)
        z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
        # i: input gate, f: forget gate, o: output gate, u: candidate update
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)
        f = tf.nn.sigmoid(f)
        o = tf.nn.sigmoid(o)
        u = tf.tanh(u)
        c = f*c + i*u
        h = o*tf.tanh(c)
        # replace the input with this step's hidden state
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s
def lstm_model(nlstm=128, layer_norm=False):
    """
    Builds LSTM (Long-Short Term Memory) network to be used in a policy.
    Note that the resulting function returns not only the output of the LSTM
    (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
    with auxiliary tensors to be set as policy attributes.

    Specifically,
    S is a placeholder to feed current state (LSTM state has to be managed outside policy)
    M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
    initial_state is a numpy array containing initial lstm state (usually zeros)
    state is the output LSTM state (to be fed into S at the next call)

    An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example

    Parameters:
    ----------
    nlstm: int LSTM hidden state size
    layer_norm: bool if True, layer-normalized version of LSTM is used

    Returns:
    -------
    function that builds LSTM with a given input tensor / placeholder
    """
    def network_fn(X, nenv=1, obs_size=-1):
        # embed discrete observations; the extra (+1) row serves as the
        # representation of the initial (first-step) observation
        with tf.variable_scope("emb", reuse=tf.AUTO_REUSE):
            w_emb = tf.get_variable("w_emb", [obs_size+1, 32])
            X = tf.nn.embedding_lookup(w_emb, X)

        nbatch = X.shape[0]
        nsteps = nbatch // nenv

        h = tf.layers.flatten(X)

        M = tf.placeholder(tf.float32, [nbatch])          # mask (done t-1)
        S = tf.placeholder(tf.float32, [nenv, 2*nlstm])   # states

        xs = batch_to_seq(h, nenv, nsteps)
        ms = batch_to_seq(M, nenv, nsteps)

        # only the plain (non layer-normalized) LSTM is implemented here
        assert not layer_norm
        h5, snew = lstm(xs, ms, S, scope='lstm', nh=nlstm)

        h = seq_to_batch(h5)
        initial_state = np.zeros(S.shape.as_list(), dtype=float)

        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
    return network_fn
def ortho_init(scale=1.0):
    """Return an initializer producing (scaled) orthogonal weight matrices."""
    def _ortho_init(shape, dtype, partition_info=None):
        # lasagne-style orthogonal init for tf
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # assumes NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        gaussian = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(gaussian, full_matrices=False)
        # pick whichever SVD factor has the requested flat shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    """Fully-connected layer: x @ w + b, variables created under ``scope``."""
    with tf.variable_scope(scope):
        nin = x.get_shape()[1].value
        weight = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
        bias = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
        return tf.matmul(x, weight) + bias
def _check_shape(placeholder_shape, data_shape):
"""
check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension)
"""
return True
# ================================================================
# Shape adjustment for feeding into tf placeholders
# ================================================================
def adjust_shape(placeholder, data):
    """
    Reshape ``data`` to the shape of ``placeholder`` when possible.

    Inputs that are neither numpy arrays nor lists are returned untouched.
    Lists are converted to numpy arrays first; unknown placeholder dimensions
    (None) become -1 so numpy infers them. AssertionError is raised when the
    shapes are declared incompatible by _check_shape.

    Parameters:
    placeholder: tensorflow input placeholder
    data: input data to be (potentially) reshaped to be fed into placeholder

    Returns:
    reshaped data
    """
    if not isinstance(data, (np.ndarray, list)):
        return data
    if isinstance(data, list):
        data = np.array(data)

    # None (unknown) dims map to -1 so np.reshape can infer them
    target_shape = [dim or -1 for dim in placeholder.shape.as_list()]

    assert _check_shape(target_shape, data.shape), \
        'Shape of data {} is not compatible with shape of the placeholder {}'.format(data.shape, target_shape)

    return np.reshape(data, target_shape)
# ================================================================
# Global session
# ================================================================
def get_session(config=None):
    """Return the default TF session, creating one (made default) if none exists."""
    sess = tf.get_default_session()
    if sess is not None:
        return sess
    return make_session(config=config, make_default=True)
def make_session(config=None, num_cpu=None, make_default=False, graph=None):
    """Returns a session that will use <num_cpu> CPU's only"""
    if num_cpu is None:
        num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
    if config is None:
        config = tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=num_cpu,
            intra_op_parallelism_threads=num_cpu)
        config.gpu_options.allow_growth = True
    # InteractiveSession installs itself as the default session
    session_cls = tf.InteractiveSession if make_default else tf.Session
    return session_cls(config=config, graph=graph)
# variables already initialized by a previous initialize() call
ALREADY_INITIALIZED = set()

def initialize():
    """Initialize all the uninitialized variables in the global scope."""
    # only touch variables not seen before, so repeated calls don't reset trained weights
    new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
    get_session().run(tf.variables_initializer(new_variables))
    ALREADY_INITIALIZED.update(new_variables)
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    """
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters:
    ----------
    ob_space: gym.Space observation space
    batch_size: int size of the batch to be fed into input. Can be left None in most cases.
    name: str name of the placeholder

    Returns:
    -------
    tensorflow placeholder tensor
    """
    assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
        'Can only deal with Discrete and Box observation spaces for now'

    dtype = ob_space.dtype
    # int8 observations are widened to uint8 before building the placeholder
    if dtype == np.int8:
        dtype = np.uint8

    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
def explained_variance(ypred, y):
    """
    Computes fraction of variance that ypred explains about y.
    Returns 1 - Var[y-ypred] / Var[y]

    interpretation:
    ev=0 => might as well have predicted zero
    ev=1 => perfect prediction
    ev<0 => worse than just predicting zero
    """
    assert y.ndim == 1 and ypred.ndim == 1
    vary = np.var(y)
    if vary == 0:
        # y is constant: the ratio is undefined
        return np.nan
    return 1 - np.var(y - ypred) / vary
......@@ -43,7 +43,8 @@ _sequence_id = platform.get_sequence_id()
def get_next_parameter():
"""Returns a set of (hyper-)paremeters generated by Tuner."""
"""Returns a set of (hyper-)paremeters generated by Tuner.
Returns None if no more (hyper-)parameters can be generated by Tuner."""
global _params
_params = platform.get_next_parameter()
if _params is None:
......
......@@ -17,11 +17,10 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
import nni
from .recoverable import Recoverable
_logger = logging.getLogger(__name__)
......
......@@ -28,9 +28,9 @@ from io import BytesIO
import json
from unittest import TestCase, main
_trials = []
_end_trials = []
_trials = [ ]
_end_trials = [ ]
class NaiveAssessor(Assessor):
def assess_trial(self, trial_job_id, trial_history):
......@@ -47,12 +47,14 @@ class NaiveAssessor(Assessor):
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
    """Rewind both buffers and cross-wire them so the test writes into the
    protocol's input and reads from its output."""
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
......
from unittest import TestCase, main
import tensorflow as tf
import torch
import torch.nn.functional as F
import nni.compression.tensorflow as tf_compressor
import nni.compression.torch as torch_compressor
def weight_variable(shape):
    """Weight variable initialized from a truncated normal (stddev 0.1)."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    """Bias variable initialized to the constant 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
def conv2d(x_input, w_matrix):
    """Stride-1, SAME-padded 2-D convolution."""
    return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')
def max_pool(x_input, pool_size):
    """Non-overlapping max-pool with a square window of ``pool_size``."""
    window = [1, pool_size, pool_size, 1]
    return tf.nn.max_pool(x_input, ksize=window, strides=window, padding='SAME')
class TfMnist:
    """
    LeNet-style MNIST network built with raw TF1 graph ops.

    Construction registers all ops/variables in the default graph (the
    compressor tests rely on that side effect). Handles to selected tensors
    (w1, b1, fcw1, cross, train_step, accuracy) are kept as attributes.
    """
    def __init__(self):
        images = tf.placeholder(tf.float32, [ None, 784 ], name = 'input_x')
        labels = tf.placeholder(tf.float32, [ None, 10 ], name = 'input_y')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.images = images
        self.labels = labels
        self.keep_prob = keep_prob
        self.train_step = None
        self.accuracy = None
        self.w1 = None    # first conv-layer weights
        self.b1 = None    # first conv-layer bias
        self.fcw1 = None  # first fully-connected-layer weights
        self.cross = None # cross-entropy loss tensor

        with tf.name_scope('reshape'):
            x_image = tf.reshape(images, [ -1, 28, 28, 1 ])
        with tf.name_scope('conv1'):
            w_conv1 = weight_variable([ 5, 5, 1, 32 ])
            self.w1 = w_conv1
            b_conv1 = bias_variable([ 32 ])
            self.b1 = b_conv1
            h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
        with tf.name_scope('pool1'):
            h_pool1 = max_pool(h_conv1, 2)
        with tf.name_scope('conv2'):
            w_conv2 = weight_variable([ 5, 5, 32, 64 ])
            b_conv2 = bias_variable([ 64 ])
            h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
        with tf.name_scope('pool2'):
            h_pool2 = max_pool(h_conv2, 2)
        with tf.name_scope('fc1'):
            w_fc1 = weight_variable([ 7 * 7 * 64, 1024 ])
            self.fcw1 = w_fc1
            b_fc1 = bias_variable([ 1024 ])
            h_pool2_flat = tf.reshape(h_pool2, [ -1, 7 * 7 * 64 ])
            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
        with tf.name_scope('dropout'):
            # NOTE(review): dropout uses the literal 0.5; the keep_prob
            # placeholder defined above is not wired in here
            h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
        with tf.name_scope('fc2'):
            w_fc2 = weight_variable([ 1024, 10 ])
            b_fc2 = bias_variable([ 10 ])
            y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
        with tf.name_scope('loss'):
            cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = y_conv))
            self.cross = cross_entropy
        with tf.name_scope('adam_optimizer'):
            self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
        with tf.name_scope('accuracy'):
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
class TorchMnist(torch.nn.Module):
    """Small LeNet-like CNN for MNIST: two conv blocks followed by two FC layers."""

    def __init__(self):
        super().__init__()
        # 28x28x1 -> conv(5x5) -> 24x24x20 -> pool -> 12x12x20
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        # 12x12x20 -> conv(5x5) -> 8x8x50 -> pool -> 4x4x50
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of (N, 1, 28, 28) images."""
        out = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2, 2)
        out = out.view(-1, 4 * 4 * 50)
        out = self.fc2(F.relu(self.fc1(out)))
        return F.log_softmax(out, dim=1)
class CompressorTestCase(TestCase):
    """Smoke tests for the NNI model-compression API.

    Each test only verifies that building a model and invoking the
    pruner/quantizer completes without raising; no assertion is made
    about the resulting sparsity or quantization.
    """
    def test_tf_pruner(self):
        # NOTE(review): `model` is unused below; TfMnist() builds its ops on the
        # TF default graph as a side effect, which compress_default_graph()
        # presumably operates on -- confirm.
        model = TfMnist()
        configure_list = [{'sparsity':0.8, 'op_types':'default'}]
        tf_compressor.LevelPruner(configure_list).compress_default_graph()

    def test_tf_quantizer(self):
        # Same pattern as test_tf_pruner: the quantizer works on the default graph.
        model = TfMnist()
        tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph()

    def test_torch_pruner(self):
        # The torch pruner takes the model explicitly rather than a global graph.
        model = TorchMnist()
        configure_list = [{'sparsity':0.8, 'op_types':'default'}]
        torch_compressor.LevelPruner(configure_list).compress(model)

    def test_torch_quantizer(self):
        model = TorchMnist()
        torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model)
if __name__ == '__main__':
    # Run the test suite; `main` is presumably unittest.main imported at the
    # top of the file (not visible in this chunk) -- confirm.
    main()
......@@ -32,7 +32,7 @@ from unittest import TestCase, main
class NaiveTuner(Tuner):
def __init__(self):
self.param = 0
self.trial_results = [ ]
self.trial_results = []
self.search_space = None
self.accept_customized_trials()
......@@ -57,12 +57,14 @@ class NaiveTuner(Tuner):
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
......@@ -70,7 +72,6 @@ def _restore_io():
nni.protocol._out_file = _out_buf
class TunerTestCase(TestCase):
def test_tuner(self):
_reverse_io() # now we are sending to Tuner's incoming stream
......@@ -94,21 +95,20 @@ class TunerTestCase(TestCase):
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
_reverse_io() # now we are receiving from Tuner's outgoing stream
self._assert_params(0, 2, [ ], None)
self._assert_params(1, 4, [ ], None)
self._assert_params(0, 2, [], None)
self._assert_params(1, 4, [], None)
command, data = receive() # this one is customized
data = json.loads(data)
self.assertIs(command, CommandType.NewTrialJob)
self.assertEqual(data['parameter_id'], 2)
self.assertEqual(data['parameter_source'], 'customized')
self.assertEqual(data['parameters'], { 'param': -1 })
self.assertEqual(data['parameters'], {'param': -1})
self._assert_params(3, 6, [[1,4,11,False], [2,-1,22,True]], {'name':'SS0'})
self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})
self.assertEqual(len(_out_buf.read()), 0) # no more commands
def _assert_params(self, parameter_id, param, trial_results, search_space):
command, data = receive()
self.assertIs(command, CommandType.NewTrialJob)
......
import * as React from 'react';
import { Row, Col } from 'antd';
import axios from 'axios';
import { COLUMN, MANAGER_IP } from './static/const';
import { COLUMN } from './static/const';
import { EXPERIMENT, TRIALS } from './static/datamodel';
import './App.css';
import SlideBar from './components/SlideBar';
interface AppState {
interval: number;
whichPageToFresh: string;
columnList: Array<string>;
concurrency: number;
interval: number;
columnList: Array<string>;
experimentUpdateBroadcast: number;
trialsUpdateBroadcast: number;
}
class App extends React.Component<{}, AppState> {
public _isMounted: boolean;
constructor(props: {}) {
super(props);
this.state = {
interval: 10, // sendons
whichPageToFresh: '',
columnList: COLUMN,
concurrency: 1
};
}
private timerId: number | null;
changeInterval = (interval: number) => {
if (this._isMounted === true) {
this.setState(() => ({ interval: interval }));
constructor(props: {}) {
super(props);
this.state = {
interval: 10, // sendons
columnList: COLUMN,
experimentUpdateBroadcast: 0,
trialsUpdateBroadcast: 0,
};
}
}
changeFresh = (fresh: string) => {
// interval * 1000
if (this._isMounted === true) {
this.setState(() => ({ whichPageToFresh: fresh }));
async componentDidMount() {
await Promise.all([ EXPERIMENT.init(), TRIALS.init() ]);
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
}
}
changeColumn = (columnList: Array<string>) => {
if (this._isMounted === true) {
this.setState(() => ({ columnList: columnList }));
changeInterval = (interval: number) => {
this.setState({ interval: interval });
if (this.timerId === null && interval !== 0) {
window.setTimeout(this.refresh);
} else if (this.timerId !== null && interval === 0) {
window.clearTimeout(this.timerId);
}
}
}
changeConcurrency = (val: number) => {
if (this._isMounted === true) {
this.setState(() => ({ concurrency: val }));
// TODO: use local storage
changeColumn = (columnList: Array<string>) => {
this.setState({ columnList: columnList });
}
}
getConcurrency = () => {
axios(`${MANAGER_IP}/experiment`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
const params = res.data.params;
if (this._isMounted) {
this.setState(() => ({ concurrency: params.trialConcurrency }));
}
render() {
const { interval, columnList, experimentUpdateBroadcast, trialsUpdateBroadcast } = this.state;
if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) {
return null; // TODO: render a loading page
}
const reactPropsChildren = React.Children.map(this.props.children, child =>
React.cloneElement(
// tslint:disable-next-line:no-any
child as React.ReactElement<any>, {
interval,
columnList, changeColumn: this.changeColumn,
experimentUpdateBroadcast,
trialsUpdateBroadcast,
})
);
return (
<Row className="nni" style={{ minHeight: window.innerHeight }}>
<Row className="header">
<Col span={1} />
<Col className="headerCon" span={22}>
<SlideBar changeInterval={this.changeInterval} />
</Col>
<Col span={1} />
</Row>
<Row className="contentBox">
<Row className="content">
{reactPropsChildren}
</Row>
</Row>
</Row>
);
}
private refresh = async () => {
const [ experimentUpdated, trialsUpdated ] = await Promise.all([ EXPERIMENT.update(), TRIALS.update() ]);
if (experimentUpdated) {
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
}
if (trialsUpdated) {
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
}
});
}
componentDidMount() {
this._isMounted = true;
this.getConcurrency();
}
if ([ 'DONE', 'ERROR', 'STOPPED' ].includes(EXPERIMENT.status)) {
// experiment finished, refresh once more to ensure consistency
if (this.state.interval > 0) {
this.setState({ interval: 0 });
this.lastRefresh();
}
componentWillUnmount() {
this._isMounted = false;
}
render() {
const { interval, whichPageToFresh, columnList, concurrency } = this.state;
const reactPropsChildren = React.Children.map(this.props.children, child =>
React.cloneElement(
// tslint:disable-next-line:no-any
child as React.ReactElement<any>, {
interval, whichPageToFresh,
columnList, changeColumn: this.changeColumn,
concurrency, changeConcurrency: this.changeConcurrency
})
);
return (
<Row className="nni" style={{ minHeight: window.innerHeight }}>
<Row className="header">
<Col span={1} />
<Col className="headerCon" span={22}>
<SlideBar changeInterval={this.changeInterval} changeFresh={this.changeFresh} />
</Col>
<Col span={1} />
</Row>
<Row className="contentBox">
<Row className="content">
{reactPropsChildren}
</Row>
</Row>
</Row>
);
}
} else if (this.state.interval !== 0) {
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
}
}
private async lastRefresh() {
await EXPERIMENT.update();
await TRIALS.update(true);
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
}
}
export default App;
......@@ -2,12 +2,13 @@ import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import { TRIALS } from '../../static/datamodel';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
import { TableRecord, Intermedia, TooltipForIntermediate } from 'src/static/interface';
// the modal of trial compare
interface CompareProps {
compareRows: Array<TableObj>;
compareRows: Array<TableRecord>;
visible: boolean;
cancelFunc: () => void;
}
......@@ -25,11 +26,12 @@ class Compare extends React.Component<CompareProps, {}> {
const idsList: Array<string> = [];
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
const trial = TRIALS.getTrial(temp.id);
trialIntermediate.push({
name: temp.id,
data: temp.description.intermediate,
data: trial.description.intermediate,
type: 'line',
hyperPara: temp.description.parameters
hyperPara: trial.description.parameters
});
idsList.push(temp.id);
});
......@@ -105,11 +107,12 @@ class Compare extends React.Component<CompareProps, {}> {
// render table column ---
initColumn = () => {
const { compareRows } = this.props;
const idList: Array<string> = [];
const sequenceIdList: Array<number> = [];
const durationList: Array<number> = [];
const compareRows = this.props.compareRows.map(tableRecord => TRIALS.getTrial(tableRecord.id));
const parameterList: Array<object> = [];
let parameterKeys: Array<string> = [];
if (compareRows.length !== 0) {
......@@ -147,7 +150,7 @@ class Compare extends React.Component<CompareProps, {}> {
const temp = compareRows[index];
return (
<td className="value" key={index}>
<IntermediateVal record={temp} />
<IntermediateVal trialId={temp.id} />
</td>
);
})}
......@@ -206,7 +209,7 @@ class Compare extends React.Component<CompareProps, {}> {
>
<Row className="compare-intermediate">
{this.intermediate()}
<Row className="compare-yAxis"># Intermediate</Row>
<Row className="compare-yAxis"># Intermediate result</Row>
</Row>
<Row>{this.initColumn()}</Row>
</Modal>
......
......@@ -58,7 +58,7 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
trialMessage: trialMessagesArr
};
if (this._isCompareMount === true) {
this.setState(() => ({ experiment: JSON.stringify(result, null, 4) }));
this.setState({ experiment: JSON.stringify(result, null, 4) });
}
}
}));
......
......@@ -51,13 +51,13 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
setDispatcher = (value: string) => {
if (this._isLogDrawer === true) {
this.setState(() => ({ isLoadispatcher: false, dispatcherLogStr: value }));
this.setState({ isLoadispatcher: false, dispatcherLogStr: value });
}
}
setNNImanager = (val: string) => {
if (this._isLogDrawer === true) {
this.setState(() => ({ isLoading: false, nniManagerLogStr: val }));
this.setState({ isLoading: false, nniManagerLogStr: val });
}
}
......
import * as React from 'react';
import axios from 'axios';
import { Row, Col } from 'antd';
import { MANAGER_IP } from '../static/const';
import { Experiment, TableObj, Parameters, TrialNumber } from '../static/interface';
import { getFinal } from '../static/function';
import { EXPERIMENT, TRIALS } from '../static/datamodel';
import { Trial } from '../static/model/trial';
import SuccessTable from './overview/SuccessTable';
import Title1 from './overview/Title1';
import Progressed from './overview/Progress';
import Accuracy from './overview/Accuracy';
import SearchSpace from './overview/SearchSpace';
import BasicInfo from './overview/BasicInfo';
import TrialPro from './overview/TrialProfile';
import TrialInfo from './overview/TrialProfile';
require('../static/style/overview.scss');
require('../static/style/logPath.scss');
......@@ -18,486 +16,70 @@ require('../static/style/accuracy.css');
require('../static/style/table.scss');
require('../static/style/overviewTitle.scss');
interface OverviewState {
tableData: Array<TableObj>;
experimentAPI: object;
searchSpace: object;
status: string;
errorStr: string;
trialProfile: Experiment;
option: object;
noData: string;
accuracyData: object;
bestAccuracy: number;
accNodata: string;
trialNumber: TrialNumber;
isTop10: boolean;
titleMaxbgcolor?: string;
titleMinbgcolor?: string;
// trial stdout is content(false) or link(true)
isLogCollection: boolean;
isMultiPhase: boolean;
interface OverviewProps {
experimentUpdateBroadcast: number;
trialsUpdateBroadcast: number;
}
interface OverviewProps {
interval: number; // user select
whichPageToFresh: string;
concurrency: number;
changeConcurrency: (val: number) => void;
interface OverviewState {
trialConcurrency: number;
metricGraphMode: 'max' | 'min';
}
class Overview extends React.Component<OverviewProps, OverviewState> {
public _isMounted = false;
public intervalID = 0;
public intervalProfile = 1;
constructor(props: OverviewProps) {
super(props);
this.state = {
searchSpace: {},
experimentAPI: {},
status: '',
errorStr: '',
trialProfile: {
id: '',
author: '',
experName: '',
runConcurren: 1,
maxDuration: 0,
execDuration: 0,
MaxTrialNum: 0,
startTime: 0,
tuner: {},
trainingServicePlatform: ''
},
tableData: [],
option: {},
noData: '',
// accuracy
accuracyData: {},
accNodata: '',
bestAccuracy: 0,
trialNumber: {
succTrial: 0,
failTrial: 0,
stopTrial: 0,
waitTrial: 0,
runTrial: 0,
unknowTrial: 0,
totalCurrentTrial: 0
},
isTop10: true,
isLogCollection: false,
isMultiPhase: false
trialConcurrency: EXPERIMENT.trialConcurrency,
metricGraphMode: (EXPERIMENT.optimizeMode === 'minimize' ? 'min' : 'max'),
};
}
// show session
showSessionPro = () => {
axios(`${MANAGER_IP}/experiment`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
let sessionData = res.data;
let trialPro = [];
const tempara = sessionData.params;
const trainingPlatform = tempara.trainingServicePlatform;
// assessor clusterMeteData
const clusterMetaData = tempara.clusterMetaData;
const endTimenum = sessionData.endTime;
const assessor = tempara.assessor;
const advisor = tempara.advisor;
let optimizeMode = 'other';
if (tempara.tuner !== undefined) {
if (tempara.tuner.classArgs !== undefined) {
if (tempara.tuner.classArgs.optimize_mode !== undefined) {
optimizeMode = tempara.tuner.classArgs.optimize_mode;
}
}
}
// default logCollection is true
const logCollection = tempara.logCollection;
let expLogCollection: boolean = false;
const isMultiy: boolean = tempara.multiPhase !== undefined
? tempara.multiPhase : false;
if (optimizeMode !== undefined) {
if (optimizeMode === 'minimize') {
if (this._isMounted) {
this.setState({
isTop10: false,
titleMinbgcolor: '#999'
});
}
} else {
if (this._isMounted) {
this.setState({
isTop10: true,
titleMaxbgcolor: '#999'
});
}
}
}
if (logCollection !== undefined && logCollection !== 'none') {
expLogCollection = true;
}
trialPro.push({
id: sessionData.id,
author: tempara.authorName,
revision: sessionData.revision,
experName: tempara.experimentName,
runConcurren: tempara.trialConcurrency,
logDir: sessionData.logDir ? sessionData.logDir : 'undefined',
maxDuration: tempara.maxExecDuration,
execDuration: sessionData.execDuration,
MaxTrialNum: tempara.maxTrialNum,
startTime: sessionData.startTime,
endTime: endTimenum ? endTimenum : undefined,
trainingServicePlatform: trainingPlatform,
tuner: tempara.tuner,
assessor: assessor ? assessor : undefined,
advisor: advisor ? advisor : undefined,
clusterMetaData: clusterMetaData ? clusterMetaData : undefined,
logCollection: logCollection
});
// search space format loguniform max and min
const temp = tempara.searchSpace;
const searchSpace = temp !== undefined
? JSON.parse(temp) : {};
Object.keys(searchSpace).map(item => {
const key = searchSpace[item]._type;
let value = searchSpace[item]._value;
switch (key) {
case 'quniform':
case 'qnormal':
case 'qlognormal':
searchSpace[item]._value = [value[0], value[1]];
break;
default:
}
});
if (this._isMounted) {
this.setState({
experimentAPI: res.data,
trialProfile: trialPro[0],
searchSpace: searchSpace,
isLogCollection: expLogCollection,
isMultiPhase: isMultiy
});
}
}
});
this.checkStatus();
}
checkStatus = () => {
axios(`${MANAGER_IP}/check-status`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
const errors = res.data.errors;
if (errors.length !== 0) {
if (this._isMounted) {
this.setState({
status: res.data.status,
errorStr: res.data.errors[0]
});
}
} else {
if (this._isMounted) {
this.setState({
status: res.data.status,
});
}
}
}
});
}
showTrials = () => {
this.isOffInterval();
axios(`${MANAGER_IP}/trial-jobs`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
const tableData = res.data;
const topTableData: Array<TableObj> = [];
const profile: TrialNumber = {
succTrial: 0,
failTrial: 0,
stopTrial: 0,
waitTrial: 0,
runTrial: 0,
unknowTrial: 0,
totalCurrentTrial: 0
};
// currently totoal number
profile.totalCurrentTrial = tableData.length;
Object.keys(tableData).map(item => {
switch (tableData[item].status) {
case 'WAITING':
profile.waitTrial += 1;
break;
case 'UNKNOWN':
profile.unknowTrial += 1;
break;
case 'FAILED':
profile.failTrial += 1;
break;
case 'RUNNING':
profile.runTrial += 1;
break;
case 'USER_CANCELED':
case 'SYS_CANCELED':
case 'EARLY_STOPPED':
profile.stopTrial += 1;
break;
case 'SUCCEEDED':
profile.succTrial += 1;
const desJobDetail: Parameters = {
parameters: {},
intermediate: [],
multiProgress: 1
};
const duration = (tableData[item].endTime - tableData[item].startTime) / 1000;
const acc = getFinal(tableData[item].finalMetricData);
// if hyperparameters is undefine, show error message, else, show parameters value
const tempara = tableData[item].hyperParameters;
if (tempara !== undefined) {
const tempLength = tempara.length;
const parameters = JSON.parse(tempara[tempLength - 1]).parameters;
desJobDetail.multiProgress = tempara.length;
if (typeof parameters === 'string') {
desJobDetail.parameters = JSON.parse(parameters);
} else {
desJobDetail.parameters = parameters;
}
} else {
desJobDetail.parameters = { error: 'This trial\'s parameters are not available.' };
}
if (tableData[item].logPath !== undefined) {
desJobDetail.logPath = tableData[item].logPath;
}
topTableData.push({
key: topTableData.length,
sequenceId: tableData[item].sequenceId,
id: tableData[item].id,
duration: duration,
status: tableData[item].status,
acc: acc,
description: desJobDetail
});
break;
default:
}
});
// choose top10 or lowest10
const { isTop10 } = this.state;
if (isTop10 === true) {
topTableData.sort((a: TableObj, b: TableObj) => {
if (a.acc !== undefined && b.acc !== undefined) {
return JSON.parse(b.acc.default) - JSON.parse(a.acc.default);
} else {
return NaN;
}
});
} else {
topTableData.sort((a: TableObj, b: TableObj) => {
if (a.acc !== undefined && b.acc !== undefined) {
return JSON.parse(a.acc.default) - JSON.parse(b.acc.default);
} else {
return NaN;
}
});
}
topTableData.length = Math.min(10, topTableData.length);
let bestDefaultMetric = 0;
if (topTableData[0] !== undefined) {
if (topTableData[0].acc !== undefined) {
bestDefaultMetric = JSON.parse(topTableData[0].acc.default);
}
}
if (this._isMounted) {
this.setState({
tableData: topTableData,
trialNumber: profile,
bestAccuracy: bestDefaultMetric
});
}
this.checkStatus();
// draw accuracy
this.drawPointGraph();
}
});
}
// trial accuracy graph Default Metric
drawPointGraph = () => {
const { tableData } = this.state;
const sourcePoint = JSON.parse(JSON.stringify(tableData));
sourcePoint.sort((a: TableObj, b: TableObj) => {
if (a.sequenceId !== undefined && b.sequenceId !== undefined) {
return a.sequenceId - b.sequenceId;
} else {
return NaN;
}
});
const accarr: Array<number> = [];
const indexarr: Array<number> = [];
Object.keys(sourcePoint).map(item => {
const items = sourcePoint[item];
if (items.acc !== undefined) {
accarr.push(items.acc.default);
indexarr.push(items.sequenceId);
}
});
const accOption = {
// support max show 0.0000000
grid: {
left: 67,
right: 40
},
tooltip: {
trigger: 'item'
},
xAxis: {
name: 'Trial',
type: 'category',
data: indexarr
},
yAxis: {
name: 'Default metric',
type: 'value',
scale: true,
data: accarr
},
series: [{
symbolSize: 6,
type: 'scatter',
data: accarr
}]
};
if (this._isMounted) {
this.setState({ accuracyData: accOption }, () => {
if (accarr.length === 0) {
this.setState({
accNodata: 'No data'
});
} else {
this.setState({
accNodata: ''
});
}
});
}
}
clickMaxTop = (event: React.SyntheticEvent<EventTarget>) => {
event.stopPropagation();
// #999 panel active bgcolor; #b3b3b3 as usual
this.setState(() => ({ isTop10: true, titleMaxbgcolor: '#999', titleMinbgcolor: '#b3b3b3' }));
this.showTrials();
this.setState({ metricGraphMode: 'max' });
}
clickMinTop = (event: React.SyntheticEvent<EventTarget>) => {
event.stopPropagation();
this.setState(() => ({ isTop10: false, titleMaxbgcolor: '#b3b3b3', titleMinbgcolor: '#999' }));
this.showTrials();
}
isOffInterval = () => {
const { status } = this.state;
const { interval } = this.props;
if (status === 'DONE' || status === 'ERROR' || status === 'STOPPED' ||
interval === 0
) {
window.clearInterval(this.intervalID);
window.clearInterval(this.intervalProfile);
return;
}
this.setState({ metricGraphMode: 'min' });
}
componentWillReceiveProps(nextProps: OverviewProps) {
const { interval, whichPageToFresh } = nextProps;
window.clearInterval(this.intervalID);
window.clearInterval(this.intervalProfile);
if (whichPageToFresh.includes('/oview')) {
this.showTrials();
this.showSessionPro();
}
if (interval !== 0) {
this.intervalID = window.setInterval(this.showTrials, interval * 1000);
this.intervalProfile = window.setInterval(this.showSessionPro, interval * 1000);
}
changeConcurrency = (val: number) => {
this.setState({ trialConcurrency: val });
}
componentDidMount() {
this._isMounted = true;
const { interval } = this.props;
this.showTrials();
this.showSessionPro();
if (interval !== 0) {
this.intervalID = window.setInterval(this.showTrials, interval * 1000);
this.intervalProfile = window.setInterval(this.showSessionPro, interval * 1000);
}
}
render() {
const { trialConcurrency, metricGraphMode } = this.state;
const { experimentUpdateBroadcast } = this.props;
componentWillUnmount() {
this._isMounted = false;
window.clearInterval(this.intervalID);
window.clearInterval(this.intervalProfile);
}
const searchSpace = this.convertSearchSpace();
render() {
const bestTrials = this.findBestTrials();
const bestAccuracy = bestTrials.length > 0 ? bestTrials[0].accuracy! : NaN;
const accuracyGraphData = this.generateAccuracyGraph(bestTrials);
const noDataMessage = bestTrials.length > 0 ? '' : 'No data';
const {
trialProfile, searchSpace, tableData, accuracyData,
accNodata, status, errorStr, trialNumber, bestAccuracy, isMultiPhase,
titleMaxbgcolor, titleMinbgcolor, isLogCollection, experimentAPI
} = this.state;
const { concurrency } = this.props;
trialProfile.runConcurren = concurrency;
Object.keys(experimentAPI).map(item => {
if (item === 'params') {
const temp = experimentAPI[item];
Object.keys(temp).map(index => {
if (index === 'trialConcurrency') {
temp[index] = concurrency;
}
});
}
});
const titleMaxbgcolor = (metricGraphMode === 'max' ? '#999' : '#b3b3b3');
const titleMinbgcolor = (metricGraphMode === 'min' ? '#999' : '#b3b3b3');
return (
<div className="overview">
{/* status and experiment block */}
<Row>
<Title1 text="Experiment" icon="11.png" />
<BasicInfo trialProfile={trialProfile} status={status} />
<BasicInfo experimentUpdateBroadcast={experimentUpdateBroadcast} />
</Row>
<Row className="overMessage">
{/* status graph */}
<Col span={9} className="prograph overviewBoder cc">
<Title1 text="Status" icon="5.png" />
<Progressed
trialNumber={trialNumber}
trialProfile={trialProfile}
bestAccuracy={bestAccuracy}
status={status}
errors={errorStr}
concurrency={concurrency}
changeConcurrency={this.props.changeConcurrency}
concurrency={trialConcurrency}
changeConcurrency={this.changeConcurrency}
experimentUpdateBroadcast={experimentUpdateBroadcast}
/>
</Col>
{/* experiment parameters search space tuner assessor... */}
......@@ -512,7 +94,10 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
<Row className="experiment">
{/* the scroll bar all the trial profile in the searchSpace div*/}
<div className="experiment searchSpace">
<TrialPro experiment={experimentAPI} />
<TrialInfo
experimentUpdateBroadcast={experimentUpdateBroadcast}
concurrency={trialConcurrency}
/>
</div>
</Row>
</Col>
......@@ -541,24 +126,79 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
<Col span={8} className="overviewBoder">
<Row className="accuracy">
<Accuracy
accuracyData={accuracyData}
accNodata={accNodata}
accuracyData={accuracyGraphData}
accNodata={noDataMessage}
height={324}
/>
</Row>
</Col>
<Col span={16} id="succeTable">
<SuccessTable
tableSource={tableData}
multiphase={isMultiPhase}
logCollection={isLogCollection}
trainingPlatform={trialProfile.trainingServicePlatform}
/>
<SuccessTable trialIds={bestTrials.map(trial => trial.info.id)}/>
</Col>
</Row>
</Row>
</div>
);
}
private convertSearchSpace(): object {
const searchSpace = Object.assign({}, EXPERIMENT.searchSpace);
Object.keys(searchSpace).map(item => {
const key = searchSpace[item]._type;
let value = searchSpace[item]._value;
switch (key) {
case 'quniform':
case 'qnormal':
case 'qlognormal':
searchSpace[item]._value = [value[0], value[1]];
break;
default:
}
});
return searchSpace;
}
private findBestTrials(): Trial[] {
let bestTrials = TRIALS.sort();
if (this.state.metricGraphMode === 'max') {
bestTrials.reverse().splice(10);
} else {
bestTrials.splice(10);
}
return bestTrials;
}
private generateAccuracyGraph(bestTrials: Trial[]): object {
const xSequence = bestTrials.map(trial => trial.sequenceId);
const ySequence = bestTrials.map(trial => trial.accuracy);
return {
// support max show 0.0000000
grid: {
left: 67,
right: 40
},
tooltip: {
trigger: 'item'
},
xAxis: {
name: 'Trial',
type: 'category',
data: xSequence
},
yAxis: {
name: 'Default metric',
type: 'value',
scale: true,
data: ySequence
},
series: [{
symbolSize: 6,
type: 'scatter',
data: ySequence
}]
};
}
}
export default Overview;
......@@ -26,7 +26,6 @@ interface SliderState {
interface SliderProps extends FormComponentProps {
changeInterval: (value: number) => void;
changeFresh: (value: string) => void;
}
interface EventPer {
......@@ -35,7 +34,6 @@ interface EventPer {
class SlideBar extends React.Component<SliderProps, SliderState> {
public _isMounted = false;
public divMenu: HTMLDivElement | null;
public selectHTML: Select | null;
......@@ -57,32 +55,26 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
method: 'GET'
})
.then(res => {
if (res.status === 200 && this._isMounted) {
if (res.status === 200) {
this.setState({ version: res.data });
}
});
}
handleMenuClick = (e: EventPer) => {
if (this._isMounted) { this.setState({ menuVisible: false }); }
this.setState({ menuVisible: false });
switch (e.key) {
// to see & download experiment parameters
case '1':
if (this._isMounted === true) {
this.setState(() => ({ isvisibleExperimentDrawer: true }));
}
this.setState({ isvisibleExperimentDrawer: true });
break;
// to see & download nnimanager log
case '2':
if (this._isMounted === true) {
this.setState(() => ({ activeKey: 'nnimanager', isvisibleLogDrawer: true }));
}
this.setState({ activeKey: 'nnimanager', isvisibleLogDrawer: true });
break;
// to see & download dispatcher log
case '3':
if (this._isMounted === true) {
this.setState(() => ({ isvisibleLogDrawer: true, activeKey: 'dispatcher' }));
}
this.setState({ isvisibleLogDrawer: true, activeKey: 'dispatcher' });
break;
case 'close':
case '10':
......@@ -96,13 +88,10 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
handleVisibleChange = (flag: boolean) => {
if (this._isMounted === true) {
this.setState({ menuVisible: flag });
}
this.setState({ menuVisible: flag });
}
getInterval = (value: string) => {
if (value === 'close') {
this.props.changeInterval(0);
} else {
......@@ -203,13 +192,9 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
fresh = (event: React.SyntheticEvent<EventTarget>) => {
event.preventDefault();
event.stopPropagation();
if (this._isMounted) {
this.setState({ isdisabledFresh: true }, () => {
const whichPage = window.location.pathname;
this.props.changeFresh(whichPage);
setTimeout(() => { this.setState(() => ({ isdisabledFresh: false })); }, 1000);
});
}
this.setState({ isdisabledFresh: true }, () => {
setTimeout(() => { this.setState({ isdisabledFresh: false }); }, 1000);
});
}
desktopHTML = () => {
......@@ -330,27 +315,18 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
// close log drawer (nnimanager.dispatcher)
closeLogDrawer = () => {
if (this._isMounted === true) {
this.setState(() => ({ isvisibleLogDrawer: false, activeKey: '' }));
}
this.setState({ isvisibleLogDrawer: false, activeKey: '' });
}
// close download experiment parameters drawer
closeExpDrawer = () => {
if (this._isMounted === true) {
this.setState(() => ({ isvisibleExperimentDrawer: false }));
}
this.setState({ isvisibleExperimentDrawer: false });
}
componentDidMount() {
this._isMounted = true;
this.getNNIversion();
}
componentWillUnmount() {
this._isMounted = false;
}
render() {
const mobile = (<MediaQuery maxWidth={884}>{this.mobileHTML()}</MediaQuery>);
const tablet = (<MediaQuery minWidth={885} maxWidth={1241}>{this.tabeltHTML()}</MediaQuery>);
......@@ -376,4 +352,4 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
}
export default Form.create<FormComponentProps>()(SlideBar);
\ No newline at end of file
export default Form.create<FormComponentProps>()(SlideBar);
import * as React from 'react';
import axios from 'axios';
import { MANAGER_IP } from '../static/const';
import { Row, Col, Tabs, Select, Button, Icon } from 'antd';
const Option = Select.Option;
import { TableObj, Parameters, ExperimentInfo } from '../static/interface';
import { getFinal } from '../static/function';
import { EXPERIMENT, TRIALS } from '../static/datamodel';
import { Trial } from '../static/model/trial';
import DefaultPoint from './trial-detail/DefaultMetricPoint';
import Duration from './trial-detail/Duration';
import Title1 from './overview/Title1';
......@@ -16,37 +14,22 @@ import '../static/style/trialsDetail.scss';
import '../static/style/search.scss';
interface TrialDetailState {
accSource: object;
accNodata: string;
tableListSource: Array<TableObj>;
searchResultSource: Array<TableObj>;
isHasSearch: boolean;
experimentLogCollection: boolean;
entriesTable: number; // table components val
entriesInSelect: string;
searchSpace: string;
isMultiPhase: boolean;
tablePageSize: number; // table components val
whichGraph: string;
hyperCounts: number; // user click the hyper-parameter counts
durationCounts: number;
intermediateCounts: number;
experimentInfo: ExperimentInfo;
searchFilter: string;
searchPlaceHolder: string;
searchType: string;
searchFilter: (trial: Trial) => boolean;
}
interface TrialsDetailProps {
interval: number;
whichPageToFresh: string;
columnList: Array<string>;
changeColumn: (val: Array<string>) => void;
experimentUpdateBroacast: number;
trialsUpdateBroadcast: number;
}
class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> {
public _isMounted = false;
public interAccuracy = 0;
public interTableList = 1;
public interAllTableList = 2;
public tableList: TableList | null;
......@@ -73,335 +56,67 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
constructor(props: TrialsDetailProps) {
super(props);
this.state = {
accSource: {},
accNodata: '',
tableListSource: [],
searchResultSource: [],
experimentLogCollection: false,
entriesTable: 20,
entriesInSelect: '20',
searchSpace: '',
tablePageSize: 20,
whichGraph: '1',
isHasSearch: false,
isMultiPhase: false,
hyperCounts: 0,
durationCounts: 0,
intermediateCounts: 0,
experimentInfo: {
platform: '',
optimizeMode: 'maximize'
},
searchFilter: 'id',
searchPlaceHolder: 'Search by id'
searchType: 'id',
searchFilter: trial => true,
};
}
getDetailSource = () => {
this.isOffIntervals();
axios
.all([
axios.get(`${MANAGER_IP}/trial-jobs`),
axios.get(`${MANAGER_IP}/metric-data`)
])
.then(axios.spread((res, res1) => {
if (res.status === 200 && res1.status === 200) {
const trialJobs = res.data;
const metricSource = res1.data;
const trialTable: Array<TableObj> = [];
Object.keys(trialJobs).map(item => {
let desc: Parameters = {
parameters: {},
intermediate: [],
multiProgress: 1
};
let duration = 0;
const id = trialJobs[item].id !== undefined
? trialJobs[item].id
: '';
const status = trialJobs[item].status !== undefined
? trialJobs[item].status
: '';
const begin = trialJobs[item].startTime;
const end = trialJobs[item].endTime;
if (begin) {
if (end) {
duration = (end - begin) / 1000;
} else {
duration = (new Date().getTime() - begin) / 1000;
}
}
const tempHyper = trialJobs[item].hyperParameters;
if (tempHyper !== undefined) {
const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
desc.multiProgress = tempHyper.length;
if (typeof getPara === 'string') {
desc.parameters = JSON.parse(getPara);
} else {
desc.parameters = getPara;
}
} else {
desc.parameters = { error: 'This trial\'s parameters are not available.' };
}
if (trialJobs[item].logPath !== undefined) {
desc.logPath = trialJobs[item].logPath;
}
const acc = getFinal(trialJobs[item].finalMetricData);
// deal with intermediate result list
const mediate: Array<number> = [];
Object.keys(metricSource).map(key => {
const items = metricSource[key];
if (items.trialJobId === id) {
// succeed trial, last intermediate result is final result
// final result format may be object
if (typeof JSON.parse(items.data) === 'object') {
mediate.push(JSON.parse(items.data).default);
} else {
mediate.push(JSON.parse(items.data));
}
}
});
desc.intermediate = mediate;
trialTable.push({
key: trialTable.length,
sequenceId: trialJobs[item].sequenceId,
id: id,
status: status,
duration: duration,
acc: acc,
description: desc,
startTime: begin,
endTime: (end !== undefined) ? end : undefined
});
});
// update search data result
const { searchResultSource, entriesInSelect } = this.state;
if (searchResultSource.length !== 0) {
const temp: Array<number> = [];
Object.keys(searchResultSource).map(index => {
temp.push(searchResultSource[index].id);
});
const searchResultList: Array<TableObj> = [];
for (let i = 0; i < temp.length; i++) {
Object.keys(trialTable).map(key => {
const item = trialTable[key];
if (item.id === temp[i]) {
searchResultList.push(item);
}
});
}
if (this._isMounted) {
this.setState(() => ({
searchResultSource: searchResultList
}));
}
}
if (this._isMounted) {
this.setState(() => ({ tableListSource: trialTable }));
}
if (entriesInSelect === 'all' && this._isMounted) {
this.setState(() => ({
entriesTable: trialTable.length
}));
}
}
}));
}
// search a trial by trial No. & trial id
searchTrial = (event: React.ChangeEvent<HTMLInputElement>) => {
const targetValue = event.target.value;
if (targetValue === '' || targetValue === ' ') {
const { tableListSource } = this.state;
if (this._isMounted) {
this.setState(() => ({
isHasSearch: false,
tableListSource: tableListSource,
}));
}
} else {
const { tableListSource, searchFilter } = this.state;
const searchResultList: Array<TableObj> = [];
Object.keys(tableListSource).map(key => {
const item = tableListSource[key];
switch (searchFilter) {
case 'id':
if (item.id.toUpperCase().includes(targetValue.toUpperCase())) {
searchResultList.push(item);
}
break;
case 'Trial No.':
if (item.sequenceId.toString() === targetValue) {
searchResultList.push(item);
}
break;
case 'status':
if (item.status.toUpperCase().includes(targetValue.toUpperCase())) {
searchResultList.push(item);
}
break;
case 'parameters':
const strParameters = JSON.stringify(item.description.parameters, null, 4);
if (strParameters.includes(targetValue)) {
searchResultList.push(item);
}
break;
default:
}
});
if (this._isMounted) {
this.setState(() => ({
searchResultSource: searchResultList,
isHasSearch: true
}));
}
}
}
// close timer
isOffIntervals = () => {
const { interval } = this.props;
if (interval === 0) {
window.clearInterval(this.interTableList);
let filter = (trial: Trial) => true;
if (!targetValue.trim()) {
this.setState({ searchFilter: filter });
return;
} else {
axios(`${MANAGER_IP}/check-status`, {
method: 'GET'
})
.then(res => {
if (res.status === 200 && this._isMounted) {
const expStatus = res.data.status;
if (expStatus === 'DONE' || expStatus === 'ERROR' || expStatus === 'STOPPED') {
window.clearInterval(this.interTableList);
return;
}
}
});
}
switch (this.state.searchType) {
case 'id':
filter = trial => trial.info.id.toUpperCase().includes(targetValue.toUpperCase());
break;
case 'Trial No.':
filter = trial => trial.info.sequenceId.toString() === targetValue;
break;
case 'status':
filter = trial => trial.info.status.toUpperCase().includes(targetValue.toUpperCase());
break;
case 'parameters':
// TODO: support filters like `x: 2` (instead of `"x": 2`)
filter = trial => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue);
break;
default:
alert(`Unexpected search filter ${this.state.searchType}`);
}
this.setState({ searchFilter: filter });
}
handleEntriesSelect = (value: string) => {
// user select isn't 'all'
if (value !== 'all') {
if (this._isMounted) {
this.setState(() => ({ entriesTable: parseInt(value, 10) }));
}
} else {
const { tableListSource } = this.state;
if (this._isMounted) {
this.setState(() => ({
entriesInSelect: 'all',
entriesTable: tableListSource.length
}));
}
}
handleTablePageSizeSelect = (value: string) => {
this.setState({ tablePageSize: value === 'all' ? -1 : parseInt(value, 10) });
}
handleWhichTabs = (activeKey: string) => {
// const which = JSON.parse(activeKey);
if (this._isMounted) {
this.setState(() => ({ whichGraph: activeKey }));
}
this.setState({ whichGraph: activeKey });
}
test = () => {
alert('TableList component was not properly initialized.');
}
getSearchFilter = (value: string) => {
updateSearchFilterType = (value: string) => {
// clear input value and re-render table
if (this.searchInput !== null) {
this.searchInput.value = '';
if (this._isMounted === true) {
this.setState(() => ({ isHasSearch: false }));
}
}
if (this._isMounted === true) {
this.setState(() => ({ searchFilter: value, searchPlaceHolder: `Search by ${value}` }));
}
}
// get and set logCollection val
checkExperimentPlatform = () => {
axios(`${MANAGER_IP}/experiment`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
const trainingPlatform: string = res.data.params.trainingServicePlatform !== undefined
?
res.data.params.trainingServicePlatform
:
'';
// default logCollection is true
const logCollection = res.data.params.logCollection;
let expLogCollection: boolean = false;
const isMultiy: boolean = res.data.params.multiPhase !== undefined
? res.data.params.multiPhase : false;
const tuner = res.data.params.tuner;
// I'll set optimize is maximize if user not set optimize
let optimize: string = 'maximize';
if (tuner !== undefined) {
if (tuner.classArgs !== undefined) {
if (tuner.classArgs.optimize_mode !== undefined) {
if (tuner.classArgs.optimize_mode === 'minimize') {
optimize = 'minimize';
}
}
}
}
if (logCollection !== undefined && logCollection !== 'none') {
expLogCollection = true;
}
if (this._isMounted) {
this.setState({
experimentInfo: { platform: trainingPlatform, optimizeMode: optimize },
searchSpace: res.data.params.searchSpace,
experimentLogCollection: expLogCollection,
isMultiPhase: isMultiy
});
}
}
});
}
componentWillReceiveProps(nextProps: TrialsDetailProps) {
const { interval, whichPageToFresh } = nextProps;
window.clearInterval(this.interTableList);
if (interval !== 0) {
this.interTableList = window.setInterval(this.getDetailSource, interval * 1000);
}
if (whichPageToFresh.includes('/detail')) {
this.getDetailSource();
}
}
componentDidMount() {
this._isMounted = true;
const { interval } = this.props;
this.getDetailSource();
this.interTableList = window.setInterval(this.getDetailSource, interval * 1000);
this.checkExperimentPlatform();
}
componentWillUnmount() {
this._isMounted = false;
window.clearInterval(this.interTableList);
this.setState({ searchType: value });
}
render() {
const {
tableListSource, searchResultSource, isHasSearch, isMultiPhase,
entriesTable, experimentInfo, searchSpace, experimentLogCollection,
whichGraph, searchPlaceHolder
} = this.state;
const source = isHasSearch ? searchResultSource : tableListSource;
const { tablePageSize, whichGraph } = this.state;
const { columnList, changeColumn } = this.props;
const source = TRIALS.filter(this.state.searchFilter);
const trialIds = TRIALS.filter(this.state.searchFilter).map(trial => trial.id);
return (
<div>
<div className="trial" id="tabsty">
......@@ -409,10 +124,9 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<TabPane tab={this.titleOfacc} key="1">
<Row className="graph">
<DefaultPoint
height={402}
showSource={source}
whichGraph={whichGraph}
optimize={experimentInfo.optimizeMode}
trialIds={trialIds}
visible={whichGraph === '1'}
trialsUpdateBroadcast={this.props.trialsUpdateBroadcast}
/>
</Row>
</TabPane>
......@@ -420,7 +134,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<Row className="graph">
<Para
dataSource={source}
expSearchSpace={searchSpace}
expSearchSpace={JSON.stringify(EXPERIMENT.searchSpace)}
whichGraph={whichGraph}
/>
</Row>
......@@ -440,7 +154,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<span>Show</span>
<Select
className="entry"
onSelect={this.handleEntriesSelect}
onSelect={this.handleTablePageSizeSelect}
defaultValue="20"
>
<Option value="20">20</Option>
......@@ -464,7 +178,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Compare
</Button>
<Select defaultValue="id" className="filter" onSelect={this.getSearchFilter}>
<Select defaultValue="id" className="filter" onSelect={this.updateSearchFilterType}>
<Option value="id">Id</Option>
<Option value="Trial No.">Trial No.</Option>
<Option value="status">Status</Option>
......@@ -473,7 +187,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<input
type="text"
className="search-input"
placeholder={searchPlaceHolder}
placeholder={`Search by ${this.state.searchType}`}
onChange={this.searchTrial}
style={{ width: 230 }}
ref={text => (this.searchInput) = text}
......@@ -481,14 +195,11 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</Col>
</Row>
<TableList
entries={entriesTable}
tableSource={source}
isMultiPhase={isMultiPhase}
platform={experimentInfo.platform}
updateList={this.getDetailSource}
logCollection={experimentLogCollection}
pageSize={tablePageSize}
tableSource={source.map(trial => trial.tableRecord)}
columnList={columnList}
changeColumn={changeColumn}
trialsUpdateBroadcast={this.props.trialsUpdateBroadcast}
ref={(tabList) => this.tableList = tabList}
/>
</div>
......@@ -496,4 +207,4 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}
}
export default TrialsDetail;
\ No newline at end of file
export default TrialsDetail;
import { Col, Row, Tooltip } from 'antd';
import * as React from 'react';
import {
Row, Col,
Tooltip
} from 'antd';
import { Experiment } from '../../static/interface';
import { EXPERIMENT } from '../../static/datamodel';
import { formatTimestamp } from '../../static/function';
interface BasicInfoProps {
trialProfile: Experiment;
status: string;
experimentUpdateBroadcast: number;
}
class BasicInfo extends React.Component<BasicInfoProps, {}> {
constructor(props: BasicInfoProps) {
super(props);
}
render() {
const { trialProfile } = this.props;
return (
<Row className="main">
<Col span={8} className="padItem basic">
<p>Name</p>
<div>{trialProfile.experName}</div>
<div>{EXPERIMENT.profile.params.experimentName}</div>
<p>ID</p>
<div>{trialProfile.id}</div>
<div>{EXPERIMENT.profile.id}</div>
</Col>
<Col span={8} className="padItem basic">
<p>Start time</p>
<div className="nowrap">
{new Date(trialProfile.startTime).toLocaleString('en-US')}
</div>
<div className="nowrap">{formatTimestamp(EXPERIMENT.profile.startTime)}</div>
<p>End time</p>
<div className="nowrap">
{
trialProfile.endTime
?
new Date(trialProfile.endTime).toLocaleString('en-US')
:
'none'
}
</div>
<div className="nowrap">{formatTimestamp(EXPERIMENT.profile.endTime)}</div>
</Col>
<Col span={8} className="padItem basic">
<p>Log directory</p>
<div className="nowrap">
<Tooltip placement="top" title={trialProfile.logDir}>
{trialProfile.logDir}
<Tooltip placement="top" title={EXPERIMENT.profile.logDir || ''}>
{EXPERIMENT.profile.logDir || 'unknown'}
</Tooltip>
</div>
<p>Training platform</p>
<div className="nowrap">
{
trialProfile.trainingServicePlatform
?
trialProfile.trainingServicePlatform
:
'none'
}
</div>
<div className="nowrap">{EXPERIMENT.profile.params.trainingServicePlatform}</div>
</Col>
</Row>
);
}
}
export default BasicInfo;
\ No newline at end of file
export default BasicInfo;
import * as React from 'react';
import { Button, Row } from 'antd';
interface ConcurrencyInputProps {
value: number;
updateValue: (val: string) => void;
}
interface ConcurrencyInputStates {
editting: boolean;
}
class ConcurrencyInput extends React.Component<ConcurrencyInputProps, ConcurrencyInputStates> {
private input = React.createRef<HTMLInputElement>();
constructor(props: ConcurrencyInputProps) {
super(props);
this.state = { editting: false };
}
save = () => {
if (this.input.current !== null) {
this.props.updateValue(this.input.current.value);
this.setState({ editting: false });
}
}
cancel = () => {
this.setState({ editting: false });
}
edit = () => {
this.setState({ editting: true });
}
render() {
if (this.state.editting) {
return (
<Row className="inputBox">
<input
type="number"
className="concurrencyInput"
defaultValue={this.props.value.toString()}
ref={this.input}
/>
<Button
type="primary"
className="tableButton editStyle"
onClick={this.save}
>
Save
</Button>
<Button
type="primary"
onClick={this.cancel}
style={{ display: 'inline-block', marginLeft: 1 }}
className="tableButton editStyle"
>
Cancel
</Button>
</Row>
);
} else {
return (
<Row className="inputBox">
<input
type="number"
className="concurrencyInput"
disabled={true}
value={this.props.value}
/>
<Button
type="primary"
className="tableButton editStyle"
onClick={this.edit}
>
Edit
</Button>
</Row>
);
}
}
}
export default ConcurrencyInput;
import * as React from 'react';
import { Row, Col, Popover, Button, message } from 'antd';
import { Row, Col, Popover, message } from 'antd';
import axios from 'axios';
import { MANAGER_IP, CONTROLTYPE } from '../../static/const';
import { Experiment, TrialNumber } from '../../static/interface';
import { MANAGER_IP } from '../../static/const';
import { EXPERIMENT, TRIALS } from '../../static/datamodel';
import { convertTime } from '../../static/function';
import ConcurrencyInput from './NumInput';
import ProgressBar from './ProgressItem';
import LogDrawer from '../Modal/LogDrawer';
import '../../static/style/progress.scss';
import '../../static/style/probar.scss';
interface ProgressProps {
trialProfile: Experiment;
concurrency: number;
trialNumber: TrialNumber;
bestAccuracy: number;
status: string;
errors: string;
changeConcurrency: (val: number) => void;
experimentUpdateBroadcast: number;
}
interface ProgressState {
btnName: string;
isEnable: boolean;
userInputVal: string; // get user input
cancelSty: string;
isShowLogDrawer: boolean;
}
class Progressed extends React.Component<ProgressProps, ProgressState> {
public conInput: HTMLInputElement | null;
public _isMounted = false;
constructor(props: ProgressProps) {
super(props);
this.state = {
btnName: 'Edit',
isEnable: true,
userInputVal: this.props.trialProfile.runConcurren.toString(),
cancelSty: 'none',
isShowLogDrawer: false
};
}
editTrialConcurrency = () => {
const { btnName } = this.state;
if (this._isMounted) {
if (btnName === 'Edit') {
// user click edit
this.setState(() => ({
isEnable: false,
btnName: 'Save',
cancelSty: 'inline-block'
}));
} else {
// user click save button
axios(`${MANAGER_IP}/experiment`, {
method: 'GET'
})
.then(rese => {
if (rese.status === 200) {
const { userInputVal } = this.state;
const experimentFile = rese.data;
const trialConcurrency = experimentFile.params.trialConcurrency;
if (userInputVal !== undefined) {
if (userInputVal === trialConcurrency.toString() || userInputVal === '0') {
message.destroy();
message.info(
`trialConcurrency's value is ${trialConcurrency}, you did not modify it`, 2);
} else {
experimentFile.params.trialConcurrency = parseInt(userInputVal, 10);
// rest api, modify trial concurrency value
axios(`${MANAGER_IP}/experiment`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json;charset=utf-8'
},
data: experimentFile,
params: {
update_type: CONTROLTYPE[1]
}
}).then(res => {
if (res.status === 200) {
message.destroy();
message.success(`Update ${CONTROLTYPE[1].toLocaleLowerCase()}
successfully`);
this.props.changeConcurrency(parseInt(userInputVal, 10));
}
})
.catch(error => {
if (error.response.status === 500) {
if (error.response.data.error) {
message.error(error.response.data.error);
} else {
message.error(
`Update ${CONTROLTYPE[1].toLocaleLowerCase()} failed`);
}
}
});
// btn -> edit
this.setState(() => ({
btnName: 'Edit',
isEnable: true,
cancelSty: 'none'
}));
}
}
}
});
}
}
}
cancelFunction = () => {
const { trialProfile } = this.props;
if (this._isMounted) {
this.setState(
() => ({
btnName: 'Edit',
isEnable: true,
cancelSty: 'none',
}));
editTrialConcurrency = async (userInput: string) => {
if (!userInput.match(/^[1-9]\d*$/)) {
message.error('Please enter a positive integer!', 2);
return;
}
if (this.conInput !== null) {
this.conInput.value = trialProfile.runConcurren.toString();
const newConcurrency = parseInt(userInput, 10);
if (newConcurrency === this.props.concurrency) {
message.info(`Trial concurrency has not changed`, 2);
return;
}
}
getUserTrialConcurrency = (event: React.ChangeEvent<HTMLInputElement>) => {
const value = event.target.value;
if (value.match(/^[1-9]\d*$/) || value === '') {
this.setState(() => ({
userInputVal: value
}));
} else {
message.error('Please enter a positive integer!', 2);
if (this.conInput !== null) {
const { trialProfile } = this.props;
this.conInput.value = trialProfile.runConcurren.toString();
const newProfile = Object.assign({}, EXPERIMENT.profile);
newProfile.params.trialConcurrency = newConcurrency;
// rest api, modify trial concurrency value
try {
const res = await axios.put(`${MANAGER_IP}/experiment`, newProfile, {
params: { update_type: 'TRIAL_CONCURRENCY' }
});
if (res.status === 200) {
message.success(`Successfully updated trial concurrency`);
// NOTE: should we do this earlier in favor of poor networks?
this.props.changeConcurrency(newConcurrency);
}
} catch (error) {
if (error.response && error.response.data.error) {
message.error(`Failed to update trial concurrency\n${error.response.data.error}`);
} else if (error.response) {
message.error(`Failed to update trial concurrency\nServer responsed ${error.response.status}`);
} else if (error.message) {
message.error(`Failed to update trial concurrency\n${error.message}`);
} else {
message.error(`Failed to update trial concurrency\nUnknown error`);
}
}
}
isShowDrawer = () => {
if (this._isMounted === true) {
this.setState(() => ({ isShowLogDrawer: true }));
}
this.setState({ isShowLogDrawer: true });
}
closeDrawer = () => {
if (this._isMounted === true) {
this.setState(() => ({ isShowLogDrawer: false }));
}
this.setState({ isShowLogDrawer: false });
}
componentWillReceiveProps() {
const { trialProfile } = this.props;
if (this.conInput !== null) {
this.conInput.value = trialProfile.runConcurren.toString();
}
}
render() {
const { bestAccuracy } = this.props;
const { isShowLogDrawer } = this.state;
componentDidMount() {
this._isMounted = true;
}
const count = TRIALS.countStatus();
const stoppedCount = count.get('USER_CANCELED')! + count.get('SYS_CANCELED')! + count.get('EARLY_STOPPED')!;
const bar2 = count.get('RUNNING')! + count.get('SUCCEEDED')! + count.get('FAILED')! + stoppedCount;
componentWillUnmount() {
this._isMounted = false;
}
const bar2Percent = (bar2 / EXPERIMENT.profile.params.maxTrialNum) * 100;
const percent = (EXPERIMENT.profile.execDuration / EXPERIMENT.profile.params.maxExecDuration) * 100;
const remaining = convertTime(EXPERIMENT.profile.params.maxExecDuration - EXPERIMENT.profile.execDuration);
const maxDuration = convertTime(EXPERIMENT.profile.params.maxExecDuration);
const maxTrialNum = EXPERIMENT.profile.params.maxTrialNum;
const execDuration = convertTime(EXPERIMENT.profile.execDuration);
render() {
const { trialProfile, trialNumber, bestAccuracy, status, errors } = this.props;
const { isEnable, btnName, cancelSty, isShowLogDrawer } = this.state;
const bar2 = trialNumber.totalCurrentTrial - trialNumber.waitTrial - trialNumber.unknowTrial;
const bar2Percent = (bar2 / trialProfile.MaxTrialNum) * 100;
const percent = (trialProfile.execDuration / trialProfile.maxDuration) * 100;
const runDuration = convertTime(trialProfile.execDuration);
const temp = trialProfile.maxDuration - trialProfile.execDuration;
let remaining;
let errorContent;
if (temp < 0) {
remaining = '0';
} else {
remaining = convertTime(temp);
}
if (errors !== '') {
if (EXPERIMENT.error) {
errorContent = (
<div className="errors">
{errors}
{EXPERIMENT.error}
<div><a href="#" onClick={this.isShowDrawer}>Learn about</a></div>
</div>
);
......@@ -196,9 +103,9 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<Row className="basic lineBasic">
<p>Status</p>
<div className="status">
<span className={status}>{status}</span>
<span className={EXPERIMENT.status}>{EXPERIMENT.status}</span>
{
status === 'ERROR'
EXPERIMENT.status === 'ERROR'
?
<Popover
placement="rightTop"
......@@ -216,26 +123,26 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<ProgressBar
who="Duration"
percent={percent}
description={runDuration}
bgclass={status}
maxString={`Max duration: ${convertTime(trialProfile.maxDuration)}`}
description={execDuration}
bgclass={EXPERIMENT.status}
maxString={`Max duration: ${maxDuration}`}
/>
<ProgressBar
who="Trial numbers"
percent={bar2Percent}
description={bar2.toString()}
bgclass={status}
maxString={`Max trial number: ${trialProfile.MaxTrialNum}`}
bgclass={EXPERIMENT.status}
maxString={`Max trial number: ${maxTrialNum}`}
/>
<Row className="basic colorOfbasic mess">
<p>Best metric</p>
<div>{bestAccuracy.toFixed(6)}</div>
<div>{isNaN(bestAccuracy) ? 'N/A' : bestAccuracy.toFixed(6)}</div>
</Row>
<Row className="mess">
<Col span={6}>
<Row className="basic colorOfbasic">
<p>Spent</p>
<div>{convertTime(trialProfile.execDuration)}</div>
<div>{execDuration}</div>
</Row>
</Col>
<Col span={6}>
......@@ -247,54 +154,32 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
<Col span={12}>
{/* modify concurrency */}
<p>Concurrency</p>
<Row className="inputBox">
<input
type="number"
disabled={isEnable}
onChange={this.getUserTrialConcurrency}
className="concurrencyInput"
ref={(input) => this.conInput = input}
/>
<Button
type="primary"
className="tableButton editStyle"
onClick={this.editTrialConcurrency}
>{btnName}
</Button>
<Button
type="primary"
onClick={this.cancelFunction}
style={{ display: cancelSty, marginLeft: 1 }}
className="tableButton editStyle"
>
Cancel
</Button>
</Row>
<ConcurrencyInput value={this.props.concurrency} updateValue={this.editTrialConcurrency} />
</Col>
</Row>
<Row className="mess">
<Col span={6}>
<Row className="basic colorOfbasic">
<p>Running</p>
<div>{trialNumber.runTrial}</div>
<div>{count.get('RUNNING')}</div>
</Row>
</Col>
<Col span={6}>
<Row className="basic colorOfbasic">
<p>Succeeded</p>
<div>{trialNumber.succTrial}</div>
<div>{count.get('SUCCEEDED')}</div>
</Row>
</Col>
<Col span={6}>
<Row className="basic">
<p>Stopped</p>
<div>{trialNumber.stopTrial}</div>
<div>{stoppedCount}</div>
</Row>
</Col>
<Col span={6}>
<Row className="basic">
<p>Failed</p>
<div>{trialNumber.failTrial}</div>
<div>{count.get('FAILED')}</div>
</Row>
</Col>
</Row>
......@@ -309,4 +194,4 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
}
}
export default Progressed;
\ No newline at end of file
export default Progressed;
......@@ -2,131 +2,83 @@ import * as React from 'react';
import { Table } from 'antd';
import OpenRow from '../public-child/OpenRow';
import DefaultMetric from '../public-child/DefaultMetrc';
import { TableObj } from '../../static/interface';
import { TRIALS } from '../../static/datamodel';
import { TableRecord } from '../../static/interface';
import { convertDuration } from '../../static/function';
import '../../static/style/tableStatus.css';
import '../../static/style/openRow.scss';
interface SuccessTableProps {
tableSource: Array<TableObj>;
trainingPlatform: string;
logCollection: boolean;
multiphase: boolean;
trialIds: string[];
}
class SuccessTable extends React.Component<SuccessTableProps, {}> {
public _isMounted = false;
function openRow(record: TableRecord) {
return (
<OpenRow trialId={record.id} />
);
}
class SuccessTable extends React.Component<SuccessTableProps, {}> {
constructor(props: SuccessTableProps) {
super(props);
}
openRow = (record: TableObj) => {
const { trainingPlatform, logCollection, multiphase } = this.props;
return (
<OpenRow
trainingPlatform={trainingPlatform}
record={record}
logCollection={logCollection}
multiphase={multiphase}
/>
);
}
componentDidMount() {
this._isMounted = true;
}
componentWillUnmount() {
this._isMounted = false;
}
render() {
const { tableSource } = this.props;
let bgColor = '';
const columns = [{
title: 'Trial No.',
dataIndex: 'sequenceId',
key: 'sequenceId',
width: 140,
className: 'tableHead'
}, {
title: 'ID',
dataIndex: 'id',
key: 'id',
width: 60,
className: 'tableHead leftTitle',
render: (text: string, record: TableObj) => {
return (
<div>{record.id}</div>
);
},
}, {
title: 'Duration',
dataIndex: 'duration',
key: 'duration',
width: 140,
sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
render: (text: string, record: TableObj) => {
let duration;
if (record.duration !== undefined) {
// duration is nagative number(-1) & 0-1
if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
duration = `${record.duration}s`;
} else {
duration = convertDuration(record.duration);
}
} else {
duration = 0;
const columns = [
{
title: 'Trial No.',
dataIndex: 'sequenceId',
width: 140,
className: 'tableHead'
}, {
title: 'ID',
dataIndex: 'id',
width: 60,
className: 'tableHead leftTitle',
render: (text: string, record: TableRecord) => {
return (
<div>{record.id}</div>
);
},
}, {
title: 'Duration',
dataIndex: 'duration',
width: 140,
render: (text: string, record: TableRecord) => {
return (
<div className="durationsty"><div>{convertDuration(record.duration)}</div></div>
);
},
}, {
title: 'Status',
dataIndex: 'status',
width: 150,
className: 'tableStatus',
render: (text: string, record: TableRecord) => {
return (
<div className={`${record.status} commonStyle`}>{record.status}</div>
);
}
return (
<div className="durationsty"><div>{duration}</div></div>
);
},
}, {
title: 'Status',
dataIndex: 'status',
key: 'status',
width: 150,
className: 'tableStatus',
render: (text: string, record: TableObj) => {
bgColor = record.status;
return (
<div className={`${bgColor} commonStyle`}>
{record.status}
</div>
);
}
}, {
title: 'Default metric',
dataIndex: 'acc',
key: 'acc',
sorter: (a: TableObj, b: TableObj) => {
if (a.acc !== undefined && b.acc !== undefined) {
return JSON.parse(a.acc.default) - JSON.parse(b.acc.default);
} else {
return NaN;
}, {
title: 'Default metric',
dataIndex: 'accuracy',
render: (text: string, record: TableRecord) => {
return (
<DefaultMetric trialId={record.id} />
);
}
},
render: (text: string, record: TableObj) => {
return (
<DefaultMetric record={record} />
);
}
}];
];
return (
<div className="tabScroll" >
<Table
columns={columns}
expandedRowRender={this.openRow}
dataSource={tableSource}
expandedRowRender={openRow}
dataSource={TRIALS.table(this.props.trialIds)}
className="commonTableStyle"
pagination={false}
/>
</div >
</div>
);
}
}
......
import * as React from 'react';
import MonacoEditor from 'react-monaco-editor';
import { MONACO } from '../../static/const';
import { EXPERIMENT } from '../../static/datamodel';
interface TrialInfoProps {
experiment: object;
experimentUpdateBroadcast: number;
concurrency: number;
}
class TrialInfo extends React.Component<TrialInfoProps, {}> {
......@@ -12,32 +14,21 @@ class TrialInfo extends React.Component<TrialInfoProps, {}> {
super(props);
}
componentWillReceiveProps(nextProps: TrialInfoProps) {
const experiments = nextProps.experiment;
Object.keys(experiments).map(key => {
switch (key) {
case 'id':
case 'logDir':
case 'startTime':
case 'endTime':
experiments[key] = undefined;
break;
case 'params':
const params = experiments[key];
Object.keys(params).map(item => {
if (item === 'experimentName' || item === 'searchSpace'
|| item === 'trainingServicePlatform') {
params[item] = undefined;
}
});
break;
default:
render() {
const blacklist = [
'id', 'logDir', 'startTime', 'endTime',
'experimentName', 'searchSpace', 'trainingServicePlatform'
];
// tslint:disable-next-line:no-any
const filter = (key: string, val: any) => {
if (key === 'trialConcurrency') {
return this.props.concurrency;
}
});
}
return blacklist.includes(key) ? undefined : val;
};
const profile = JSON.stringify(EXPERIMENT.profile, filter, 2);
render() {
const { experiment } = this.props;
// FIXME: highlight not working?
return (
<div className="profile">
<MonacoEditor
......@@ -45,7 +36,7 @@ class TrialInfo extends React.Component<TrialInfoProps, {}> {
height="361"
language="json"
theme="vs-light"
value={JSON.stringify(experiment, null, 2)}
value={profile}
options={MONACO}
/>
</div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment