Commit d6febf29 authored by suiguoxin

Merge branch 'master' of git://github.com/microsoft/nni

parents 77c95479 c2179921
......@@ -20,15 +20,15 @@
import copy
import logging
import numpy as np
import os
import random
import statistics
import sys
import warnings
from enum import Enum, unique
from multiprocessing.dummy import Pool as ThreadPool
import numpy as np
import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation
import nni.metis_tuner.lib_data as lib_data
import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model
......@@ -42,8 +42,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward
logger = logging.getLogger("Metis_Tuner_AutoML")
NONE_TYPE = ''
CONSTRAINT_LOWERBOUND = None
CONSTRAINT_UPPERBOUND = None
......@@ -93,7 +91,7 @@ class MetisTuner(Tuner):
self.space = None
self.no_resampling = no_resampling
self.no_candidates = no_candidates
self.optimize_mode = optimize_mode
self.optimize_mode = OptimizeMode(optimize_mode)
self.key_order = []
self.cold_start_num = cold_start_num
self.selection_num_starting_points = selection_num_starting_points
......@@ -174,7 +172,7 @@ class MetisTuner(Tuner):
return output
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Generate next parameter for trial
If the number of trial results is lower than the cold start number,
Metis will first generate some parameters at random.
......@@ -205,7 +203,7 @@ class MetisTuner(Tuner):
return results
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Tuner receive result from trial.
Parameters
......@@ -254,6 +252,9 @@ class MetisTuner(Tuner):
threshold_samplessize_resampling=50, no_candidates=False,
minimize_starting_points=None, minimize_constraints_fun=None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
next_candidate = None
candidates = []
samples_size_all = sum([len(i) for i in samples_y])
......@@ -271,13 +272,12 @@ class MetisTuner(Tuner):
minimize_constraints_fun=minimize_constraints_fun)
if not lm_current:
return None
if no_candidates is False:
candidates.append({'hyperparameter': lm_current['hyperparameter'],
logger.info({'hyperparameter': lm_current['hyperparameter'],
'expected_mu': lm_current['expected_mu'],
'expected_sigma': lm_current['expected_sigma'],
'reason': "exploitation_gp"})
if no_candidates is False:
# ===== STEP 2: Get recommended configurations for exploration =====
results_exploration = gp_selection.selection(
"lc",
......@@ -290,34 +290,48 @@ class MetisTuner(Tuner):
if results_exploration is not None:
if _num_past_samples(results_exploration['hyperparameter'], samples_x, samples_y) == 0:
candidates.append({'hyperparameter': results_exploration['hyperparameter'],
temp_candidate = {'hyperparameter': results_exploration['hyperparameter'],
'expected_mu': results_exploration['expected_mu'],
'expected_sigma': results_exploration['expected_sigma'],
'reason': "exploration"})
'reason': "exploration"}
candidates.append(temp_candidate)
logger.info("DEBUG: 1 exploration candidate selected\n")
logger.info(temp_candidate)
else:
logger.info("DEBUG: No suitable exploration candidates were")
# ===== STEP 3: Get recommended configurations for exploitation =====
if samples_size_all >= threshold_samplessize_exploitation:
print("Getting candidates for exploitation...\n")
logger.info("Getting candidates for exploitation...\n")
try:
gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation)
results_exploitation = gmm_selection.selection(
x_bounds,
x_types,
if ("discrete_int" in x_types) or ("range_int" in x_types):
results_exploitation = gmm_selection.selection(x_bounds, x_types,
gmm['clusteringmodel_good'],
gmm['clusteringmodel_bad'],
minimize_starting_points,
minimize_constraints_fun=minimize_constraints_fun)
else:
# If all parameters are of "range_continuous", let's use GMM to generate random starting points
results_exploitation = gmm_selection.selection_r(x_bounds, x_types,
gmm['clusteringmodel_good'],
gmm['clusteringmodel_bad'],
num_starting_points=self.selection_num_starting_points,
minimize_constraints_fun=minimize_constraints_fun)
if results_exploitation is not None:
if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0:
candidates.append({'hyperparameter': results_exploitation['hyperparameter'],\
'expected_mu': results_exploitation['expected_mu'],\
'expected_sigma': results_exploitation['expected_sigma'],\
'reason': "exploitation_gmm"})
temp_expected_mu, temp_expected_sigma = gp_prediction.predict(results_exploitation['hyperparameter'], gp_model['model'])
temp_candidate = {'hyperparameter': results_exploitation['hyperparameter'],
'expected_mu': temp_expected_mu,
'expected_sigma': temp_expected_sigma,
'reason': "exploitation_gmm"}
candidates.append(temp_candidate)
logger.info("DEBUG: 1 exploitation_gmm candidate selected\n")
logger.info(temp_candidate)
else:
logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n")
......@@ -338,11 +352,13 @@ class MetisTuner(Tuner):
if results_outliers is not None:
for results_outlier in results_outliers:
if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x:
candidates.append({'hyperparameter': samples_x[results_outlier['samples_idx']],\
temp_candidate = {'hyperparameter': samples_x[results_outlier['samples_idx']],\
'expected_mu': results_outlier['expected_mu'],\
'expected_sigma': results_outlier['expected_sigma'],\
'reason': "resampling"})
'reason': "resampling"}
candidates.append(temp_candidate)
logger.info("DEBUG: %d re-sampling candidates selected\n")
logger.info(temp_candidate)
else:
logger.info("DEBUG: No suitable resampling candidates were found\n")
......
......@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
import logging
from collections import defaultdict
import json_tricks
......@@ -26,7 +25,7 @@ import json_tricks
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__)
......@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False):
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
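# For example (illustrative values), a multi-phase parameter packed by the
# function above serializes to:
#
#     _pack_parameter(7, {'lr': 0.01}, trial_job_id='abc', parameter_index=2)
#     # -> '{"parameter_id": 7, "parameter_source": "algorithm",
#     #      "parameters": {"lr": 0.01}, "trial_job_id": "abc", "parameter_index": 2}'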
class MsgDispatcher(MsgDispatcherBase):
......@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
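# As implied by the asserts above, a REQUEST_PARAMETER command carries a
# payload of roughly this shape (field values are hypothetical):
#
#     {'type': 'REQUEST_PARAMETER', 'trial_job_id': 'abc', 'parameter_index': 1}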
......@@ -160,7 +170,13 @@ class MsgDispatcher(MsgDispatcherBase):
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
if multi_phase_enabled():
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
if multi_phase_enabled():
self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
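# A tuner that wants the multi-phase context can read it from **kwargs;
# a minimal sketch of a hypothetical subclass (not part of this commit):
#
#     class MyTuner(Tuner):
#         def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
#             trial_job_id = kwargs.get('trial_job_id')  # set only when multi-phase is enabled
#             ...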
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from collections import defaultdict
import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
_logger = logging.getLogger(__name__)
# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
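# _sort_history returns the longest gap-free prefix of the reported metrics;
# for example:
#
#     _sort_history({0: 'a', 1: 'b', 3: 'd'})   # -> ['a', 'b'] (stops at the missing 2)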
# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: parameter ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')
def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()
def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()
def handle_initialize(self, data):
'''
data is search space
'''
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
return True
def handle_import_data(self, data):
"""import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
return True
def handle_report_metric_data(self, data):
trial_job_id = data['trial_job_id']
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id)
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
return True
def handle_trial_end(self, data):
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
try:
result = self.assessor.assess_trial(trial_job_id, ordered_history)
except Exception:
_logger.exception('Assessor error')
return True
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from nni.recoverable import Recoverable
_logger = logging.getLogger(__name__)
class MultiPhaseTuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once for each parameter id by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list]
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success, trial_job_id):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def load_checkpoint(self):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def save_checkpoint(self):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def _on_exit(self):
pass
def _on_error(self):
pass
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
from . import trial
def classic_mode(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size):
'''Execute the chosen function and inputs directly.
In this mode, the trial code only runs the chosen subgraph (i.e., the chosen ops and inputs),
without touching the full model graph.'''
if trial._params is None:
trial.get_next_parameter()
mutable_block = trial.get_current_parameter(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
real_chosen_inputs = [optional_inputs[input_name]
for input_name in chosen_inputs]
layer_out = funcs[chosen_layer](
[fixed_inputs, real_chosen_inputs], **funcs_args[chosen_layer])
return layer_out
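# A sketch of the classic-mode contract, assuming the tuner sent
# {'chosen_layer': 'conv5', 'chosen_inputs': ['y']} for this layer
# (hypothetical names): the call above reduces to
#
#     funcs['conv5']([fixed_inputs, [optional_inputs['y']]], **funcs_args['conv5'])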
def enas_mode(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size,
tf):
'''In ENAS mode, we build the full model graph in the trial but only run a subgraph.
This is implemented by masking inputs and branching ops.
Specifically, based on the received subgraph (through nni.get_next_parameter),
we know which inputs should be masked and which op should be executed.'''
name_prefix = "{}_{}".format(mutable_id, mutable_layer_id)
# store namespace
if 'name_space' not in globals():
global name_space
name_space = dict()
name_space[mutable_id] = True
name_space[name_prefix] = dict()
name_space[name_prefix]['funcs'] = list(funcs)
name_space[name_prefix]['optional_inputs'] = list(optional_inputs)
# create tensorflow variables as 1/0 signals used to form subgraph
if 'tf_variables' not in globals():
global tf_variables
tf_variables = dict()
name_for_optional_inputs = name_prefix + '_optional_inputs'
name_for_funcs = name_prefix + '_funcs'
tf_variables[name_prefix] = dict()
tf_variables[name_prefix]['optional_inputs'] = tf.get_variable(name_for_optional_inputs,
[len(
optional_inputs)],
dtype=tf.bool,
trainable=False)
tf_variables[name_prefix]['funcs'] = tf.get_variable(
name_for_funcs, [], dtype=tf.int64, trainable=False)
# get real values using their variable names
real_optional_inputs_value = [optional_inputs[name]
for name in name_space[name_prefix]['optional_inputs']]
real_func_value = [funcs[name]
for name in name_space[name_prefix]['funcs']]
real_funcs_args = [funcs_args[name]
for name in name_space[name_prefix]['funcs']]
# build the tensorflow graph that gets the chosen inputs by masking
real_chosen_inputs = tf.boolean_mask(
real_optional_inputs_value, tf_variables[name_prefix]['optional_inputs'])
# build tensorflow graph of different branches by using tf.case
branches = dict()
for func_id in range(len(funcs)):
func_output = real_func_value[func_id](
[fixed_inputs, real_chosen_inputs], **real_funcs_args[func_id])
branches[tf.equal(tf_variables[name_prefix]['funcs'],
func_id)] = lambda: func_output
layer_out = tf.case(branches, exclusive=True,
default=lambda: func_output)
return layer_out
def oneshot_mode(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size,
tf):
'''Similar to enas mode, oneshot mode also builds the full model graph.
The difference is that oneshot mode does not receive a subgraph.
Instead, it uses dropout to randomly drop inputs and ops.'''
# NNI requires calling get_next_parameter before reporting a result, but the parameter is not used in this mode
if trial._params is None:
trial.get_next_parameter()
optional_inputs = list(optional_inputs.values())
inputs_num = len(optional_inputs)
# Calculate the dropout rate according to the formula r^(1/k), where r is a hyper-parameter and k is the number of inputs
if inputs_num > 0:
rate = 0.01 ** (1 / inputs_num)
noise_shape = [inputs_num] + [1] * len(optional_inputs[0].get_shape())
optional_inputs = tf.nn.dropout(
optional_inputs, rate=rate, noise_shape=noise_shape)
optional_inputs = [optional_inputs[idx] for idx in range(inputs_num)]
layer_outs = [func([fixed_inputs, optional_inputs], **funcs_args[func_name])
for func_name, func in funcs.items()]
layer_out = tf.add_n(layer_outs)
return layer_out
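# A quick check of the r^(1/k) formula with r = 0.01 (the drop rate rises
# toward 1 as the number of candidate inputs k grows):
#
#     >>> [round(0.01 ** (1 / k), 3) for k in (1, 2, 5, 10)]
#     [0.01, 0.1, 0.398, 0.631]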
def reload_tensorflow_variables(session, tf=None):
'''In ENAS mode, this function reloads every signal variable created in the `enas_mode` function so
that the whole tensorflow graph switches to the particular subgraph received from the tuner.
---------------
session: the tensorflow session created by users
tf: tensorflow module
'''
subgraph_from_tuner = trial.get_next_parameter()
for mutable_id, mutable_block in subgraph_from_tuner.items():
if mutable_id not in name_space:
continue
for mutable_layer_id, mutable_layer in mutable_block.items():
name_prefix = "{}_{}".format(mutable_id, mutable_layer_id)
# extract layer information from the subgraph sampled by tuner
chosen_layer = name_space[name_prefix]['funcs'].index(
mutable_layer["chosen_layer"])
chosen_inputs = [1 if inp in mutable_layer["chosen_inputs"]
else 0 for inp in name_space[name_prefix]['optional_inputs']]
# load this information into the pre-defined tensorflow variables
tf_variables[name_prefix]['funcs'].load(chosen_layer, session)
tf_variables[name_prefix]['optional_inputs'].load(
chosen_inputs, session)
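# The chosen_inputs mask built above, in isolation (hypothetical names):
#
#     >>> optional = ['x', 'y', 'z']
#     >>> [1 if inp in ['y'] else 0 for inp in optional]
#     [0, 1, 0]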
......@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
"""
self.search_space = search_space
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""
Returns a set of trial neural architecture, as a serializable object.
......@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
return json_out
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
""" Record an observation of the objective function.
Parameters
......
......@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
else:
self.logger.warning('update search space is not supported.')
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""receive_trial_result
Parameters
......@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
converted_dict[key] = value
return converted_dict
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""generate one instance of hyperparameters
Parameters
......@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
self.total_data[parameter_id] = challenger
return self.convert_loguniform_categorical(challenger.get_dictionary())
def generate_multiple_parameters(self, parameter_id_list):
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""generate mutiple instances of hyperparameters
Parameters
......
......@@ -23,6 +23,7 @@ import random
from .env_vars import trial_env_vars
from . import trial
from .nas_utils import classic_mode, enas_mode, oneshot_mode
__all__ = [
......@@ -124,7 +125,9 @@ else:
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size):
optional_input_size,
mode='classic_mode',
tf=None):
'''execute the chosen function and inputs.
Below is an example of chosen function and inputs:
{
......@@ -144,14 +147,38 @@ else:
fixed_inputs:
optional_inputs: dict of optional inputs
optional_input_size: number of candidate inputs to be chosen
tf: tensorflow module
'''
mutable_block = _get_param(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
real_chosen_inputs = [optional_inputs[input_name] for input_name in chosen_inputs]
layer_out = funcs[chosen_layer]([fixed_inputs, real_chosen_inputs], **funcs_args[chosen_layer])
return layer_out
if mode == 'classic_mode':
return classic_mode(mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size)
elif mode == 'enas_mode':
assert tf is not None, 'Internal Error: Tensorflow should not be None in enas_mode'
return enas_mode(mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size,
tf)
elif mode == 'oneshot_mode':
assert tf is not None, 'Internal Error: Tensorflow should not be None in oneshot_mode'
return oneshot_mode(mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size,
tf)
else:
raise RuntimeError('Unrecognized mode: %s' % mode)
def _get_param(key):
if trial._params is None:
......
......@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
class Tuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once for each parameter id by default.
User code must override either this function or 'generate_parameters()'.
......@@ -49,13 +49,13 @@ class Tuner(Recoverable):
for parameter_id in parameter_id_list:
try:
_logger.debug("generating param for {}".format(parameter_id))
res = self.generate_parameters(parameter_id)
res = self.generate_parameters(parameter_id, **kwargs)
except nni.NoMoreTrialError:
return result
result.append(res)
return result
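# Since the loop above forwards **kwargs to generate_parameters, a minimal
# custom tuner only needs the per-id method; a hypothetical sketch
# (assumes `import random`):
#
#     class RandomTuner(Tuner):
#         def generate_parameters(self, parameter_id, **kwargs):
#             return {'lr': random.choice([0.01, 0.001])}
#         def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
#             pass
#         def update_search_space(self, search_space):
#             pass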
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameters: object created by 'generate_parameters()'
......@@ -63,7 +63,7 @@ class Tuner(Recoverable):
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value):
def receive_customized_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameters: object created by user
......@@ -71,7 +71,7 @@ class Tuner(Recoverable):
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success):
def trial_end(self, parameter_id, success, **kwargs):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
import random
from io import BytesIO
import nni
import nni.protocol
from nni.protocol import CommandType, send, receive
from nni.multi_phase.multi_phase_tuner import MultiPhaseTuner
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
from unittest import TestCase, main
class NaiveMultiPhaseTuner(MultiPhaseTuner):
'''
supports only choices
'''
def __init__(self):
self.search_space = None
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters = {}
if self.search_space is None:
raise AssertionError('Search space not specified')
for k in self.search_space:
param = self.search_space[k]
if param['_type'] != 'choice':
raise ValueError('Only choice type is supported')
param_values = param['_value']
generated_parameters[k] = param_values[random.randint(0, len(param_values)-1)]
logging.getLogger(__name__).debug(generated_parameters)
return generated_parameters
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
logging.getLogger(__name__).debug('receive_trial_result: {},{},{},{}'.format(parameter_id, parameters, value, trial_job_id))
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
pass
def update_search_space(self, search_space):
self.search_space = search_space
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._in_file = _in_buf
nni.protocol._out_file = _out_buf
def _test_tuner():
_reverse_io() # now we are sending to Tuner's incoming stream
send(CommandType.UpdateSearchSpace, "{\"learning_rate\": {\"_value\": [0.0001, 0.001, 0.002, 0.005, 0.01], \"_type\": \"choice\"}, \"optimizer\": {\"_value\": [\"Adam\", \"SGD\"], \"_type\": \"choice\"}}")
send(CommandType.RequestTrialJobs, '2')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}')
send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}')
send(CommandType.RequestTrialJobs, '1')
send(CommandType.TrialEnd, '{"trial_job_id":"abc"}')
_restore_io()
tuner = NaiveMultiPhaseTuner()
dispatcher = MultiPhaseMsgDispatcher(tuner)
dispatcher.run()
_reverse_io() # now we are receiving from Tuner's outgoing stream
command, data = receive() # this one is customized
print(command, data)
class MultiPhaseTestCase(TestCase):
def test_tuner(self):
_test_tuner()
if __name__ == '__main__':
main()
\ No newline at end of file
......@@ -38,7 +38,13 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice3/choice': '[1, 2]',
'test_smartparam/choice4/choice': '{"a", 2}',
'test_smartparam/func/function_choice': 'bar',
'test_smartparam/lambda_func/function_choice': "lambda: 2*3"
'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
'mutable_block_66':{
'mutable_layer_0':{
'chosen_layer': 'conv2D(size=5)',
'chosen_inputs': ['y']
}
}
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
......@@ -61,6 +67,13 @@ class SmartParamTestCase(TestCase):
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func', key='test_smartparam/lambda_func/function_choice')
self.assertEqual(val, 6)
def test_mutable_layer(self):
layer_out = nni.mutable_layer('mutable_block_66',
'mutable_layer_0', {'conv2D(size=3)': conv2D, 'conv2D(size=5)': conv2D}, {'conv2D(size=3)':
{'size':3}, 'conv2D(size=5)': {'size':5}}, [100], {'x':1,'y':2}, 1, 'classic_mode')
self.assertEqual(layer_out, [100, 2, 5])
def foo():
return 'foo'
......@@ -68,6 +81,8 @@ def foo():
def bar():
return 'bar'
def conv2D(inputs, size=3):
return inputs[0] + inputs[1] + [size]
if __name__ == '__main__':
main()
......@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
self.trial_results = [ ]
self.search_space = None
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
self.param += 2
......@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
'search_space': self.search_space
}
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
reward = extract_scalar_reward(value)
self.trial_results.append((parameter_id, parameters['param'], reward, False))
......@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
command, data = receive() # this one is customized
data = json.loads(data)
self.assertIs(command, CommandType.NewTrialJob)
self.assertEqual(data, {
'parameter_id': 2,
'parameter_source': 'customized',
'parameters': { 'param': -1 }
})
self.assertEqual(data['parameter_id'], 2)
self.assertEqual(data['parameter_source'], 'customized')
self.assertEqual(data['parameters'], { 'param': -1 })
self._assert_params(3, 6, [[1,4,11,False], [2,-1,22,True]], {'name':'SS0'})
......
import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
// the trial comparison modal
interface CompareProps {
compareRows: Array<TableObj>;
visible: boolean;
cancelFunc: () => void;
}
class Compare extends React.Component<CompareProps, {}> {
public _isCompareMount: boolean;
constructor(props: CompareProps) {
super(props);
}
intermediate = () => {
const { compareRows } = this.props;
const trialIntermediate: Array<Intermedia> = [];
const idsList: Array<string> = [];
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
trialIntermediate.push({
name: temp.id,
data: temp.description.intermediate,
type: 'line',
hyperPara: temp.description.parameters
});
idsList.push(temp.id);
});
// find max intermediate number
trialIntermediate.sort((a, b) => { return (b.data.length - a.data.length); });
const legend: Array<string> = [];
// max length
const length = trialIntermediate[0] !== undefined ? trialIntermediate[0].data.length : 0;
const xAxis: Array<number> = [];
Object.keys(trialIntermediate).map(item => {
const temp = trialIntermediate[item];
legend.push(temp.name);
});
for (let i = 1; i <= length; i++) {
xAxis.push(i);
}
const option = {
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForIntermediate) {
if (data.dataIndex < length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForIntermediate) {
const trialId = data.seriesName;
let obj = {};
const temp = trialIntermediate.find(key => key.name === trialId);
if (temp !== undefined) {
obj = temp.hyperPara;
}
return '<div class="tooldetailAccuracy">' +
'<div>Trial ID: ' + trialId + '</div>' +
'<div>Intermediate: ' + data.data + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(obj, null, 4) + '</pre>' +
'</div>' +
'</div>';
}
},
grid: {
left: '5%',
top: 40,
containLabel: true
},
legend: {
data: idsList
},
xAxis: {
type: 'category',
name: 'Step',
boundaryGap: false,
data: xAxis
},
yAxis: {
type: 'value',
name: 'metric'
},
series: trialIntermediate
};
return (
<ReactEcharts
option={option}
style={{ width: '100%', height: 418, margin: '0 auto' }}
notMerge={true} // update now
/>
);
}
// render table column ---
initColumn = () => {
const { compareRows } = this.props;
const idList: Array<string> = [];
const durationList: Array<number> = [];
const parameterList: Array<object> = [];
let parameterKeys: Array<string> = [];
if (compareRows.length !== 0) {
parameterKeys = Object.keys(compareRows[0].description.parameters);
}
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
idList.push(temp.id);
durationList.push(temp.duration);
parameterList.push(temp.description.parameters);
});
return (
<table className="compare">
<tbody>
<tr>
<td />
{Object.keys(idList).map(key => {
return (
<td className="value" key={key}>{idList[key]}</td>
);
})}
</tr>
<tr>
<td className="column">Default metric</td>
{Object.keys(compareRows).map(index => {
const temp = compareRows[index];
return (
<td className="value" key={index}>
<IntermediateVal record={temp}/>
</td>
);
})}
</tr>
<tr>
<td className="column">duration</td>
{Object.keys(durationList).map(index => {
return (
<td className="value" key={index}>{durationList[index]}</td>
);
})}
</tr>
{
Object.keys(parameterKeys).map(index => {
return (
<tr key={index}>
<td className="column" key={index}>{parameterKeys[index]}</td>
{
Object.keys(parameterList).map(key => {
return (
<td key={key} className="value">
{parameterList[key][parameterKeys[index]]}
</td>
);
})
}
</tr>
);
})
}
</tbody>
</table>
);
}
componentDidMount() {
this._isCompareMount = true;
}
componentWillUnmount() {
this._isCompareMount = false;
}
render() {
const { visible, cancelFunc } = this.props;
return (
<Modal
title="Compare trials"
visible={visible}
onCancel={cancelFunc}
footer={null}
destroyOnClose={true}
maskClosable={false}
width="90%"
>
<Row>{this.intermediate()}</Row>
<Row>{this.initColumn()}</Row>
</Modal>
);
}
}
export default Compare;
......@@ -353,8 +353,10 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
const indexarr: Array<number> = [];
Object.keys(sourcePoint).map(item => {
const items = sourcePoint[item];
if (items.acc !== undefined) {
accarr.push(items.acc.default);
indexarr.push(items.sequenceId);
}
});
const accOption = {
// support max show 0.0000000
......
......@@ -29,7 +29,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
public _isMounted = false;
public divMenu: HTMLDivElement | null;
public countOfMenu: number = 0;
public selectHTML: Select | null;
constructor(props: SliderProps) {
......@@ -208,7 +207,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
menu = () => {
this.countOfMenu = 0;
return (
<Menu onClick={this.handleMenuClick}>
<Menu.Item key="1">Experiment Parameters</Menu.Item>
......@@ -223,7 +221,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
const { version } = this.state;
const feedBackLink = `https://github.com/Microsoft/nni/issues/new?labels=${version}`;
return (
<Menu onClick={this.handleMenuClick} mode="inline">
<Menu onClick={this.handleMenuClick} className="menuModal">
<Menu.Item key="overview"><Link to={'/oview'}>Overview</Link></Menu.Item>
<Menu.Item key="detail"><Link to={'/detail'}>Trials detail</Link></Menu.Item>
<Menu.Item key="fresh">
......@@ -250,18 +248,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
// nav bar <1299
showMenu = () => {
if (this.divMenu !== null) {
this.countOfMenu = this.countOfMenu + 1;
if (this.countOfMenu % 2 === 0) {
this.divMenu.setAttribute('class', 'hide');
} else {
this.divMenu.setAttribute('class', 'show');
}
}
}
select = () => {
return (
<Select
......@@ -322,7 +308,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
</li>
<li className="feedback">
<span className="fresh" onClick={this.fresh}>
<Icon type="sync"/><span>Fresh</span>
<Icon type="sync" /><span>Fresh</span>
</span>
<Dropdown
className="dropdown"
......@@ -350,8 +336,9 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
<MediaQuery query="(max-width: 1299px)">
<Row className="little">
<Col span={6} className="menu">
<Icon type="unordered-list" className="more" onClick={this.showMenu} />
<div ref={div => this.divMenu = div} className="hide">{this.navigationBar()}</div>
<Dropdown overlay={this.navigationBar()} trigger={['click']}>
<Icon type="unordered-list" className="more" />
</Dropdown>
</Col>
<Col span={10} className="logo">
<Link to={'/oview'}>
......
......@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
{/* trial table list */}
<Title1 text="Trial jobs" icon="6.png" />
<Row className="allList">
<Col span={12}>
<Col span={10}>
<span>Show</span>
<Select
className="entry"
......@@ -392,9 +392,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</Select>
<span>entries</span>
</Col>
<Col span={12} className="right">
<Row>
<Col span={12}>
<Col span={14} className="right">
<Button
type="primary"
className="tableButton editStyle"
......@@ -402,8 +400,14 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Add column
</Button>
</Col>
<Col span={12}>
<Button
type="primary"
className="tableButton editStyle mediateBtn"
// invoke the child component TableList's compareBtn function
onClick={this.tableList ? this.tableList.compareBtn : this.test}
>
Compare
</Button>
<Input
type="text"
placeholder="Search by id, trial No. or status"
......@@ -412,8 +416,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
/>
</Col>
</Row>
</Col>
</Row>
<TableList
entries={entriesTable}
tableSource={source}
......
import * as React from 'react';
import { Row, Col, Button, Switch } from 'antd';
import { TooltipForIntermediate, TableObj } from '../../static/interface';
import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
import ReactEcharts from 'echarts-for-react';
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
interface Intermedia {
name: string; // id
type: string;
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
interface IntermediateState {
detailSource: Array<TableObj>;
interSource: object;
......
......@@ -145,7 +145,8 @@ class Para extends React.Component<ParaProps, ParaState> {
const parallelAxis: Array<Dimobj> = [];
// search space range and specific value [only number]
for (let i = 0; i < dimName.length; i++) {
let i = 0;
for (i; i < dimName.length; i++) {
const searchKey = searchRange[dimName[i]];
switch (searchKey._type) {
case 'uniform':
......@@ -213,6 +214,13 @@ class Para extends React.Component<ParaProps, ParaState> {
}
}
parallelAxis.push({
dim: i,
name: 'default metric',
nameTextStyle: {
fontWeight: 700
}
});
if (lenOfDataSource === 0) {
const optionOfNull = {
parallelAxis,
......@@ -229,8 +237,8 @@ class Para extends React.Component<ParaProps, ParaState> {
const length = value.length;
if (length > 16) {
const temp = value.split('');
for (let i = 16; i < temp.length; i += 17) {
temp[i] += '\n';
for (let m = 16; m < temp.length; m += 17) {
temp[m] += '\n';
}
return temp.join('');
} else {
......
import * as React from 'react';
import axios from 'axios';
import ReactEcharts from 'echarts-for-react';
import { Row, Table, Button, Popconfirm, Modal, Checkbox } from 'antd';
import { Row, Table, Button, Popconfirm, Modal, Checkbox, Select } from 'antd';
const Option = Select.Option;
const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare';
import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import '../../static/style/search.scss';
require('../../static/style/tableStatus.css');
......@@ -38,6 +40,12 @@ interface TableListState {
isObjFinal: boolean;
isShowColumn: boolean;
columnSelected: Array<string>; // user select columnKeys
selectRows: Array<TableObj>;
isShowCompareModal: boolean;
selectedRowKeys: string[] | number[];
intermediateData: Array<object>; // a trial's intermediate results (may include dicts)
intermediateId: string;
intermediateOtherKeys: Array<string>;
}
interface ColumnIndex {
......@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
public _isMounted = false;
public intervalTrialLog = 10;
public _trialId: string;
public tables: Table<TableObj> | null;
constructor(props: TableListProps) {
super(props);
......@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
modalVisible: false,
isObjFinal: false,
isShowColumn: false,
columnSelected: COLUMN
isShowCompareModal: false,
columnSelected: COLUMN,
selectRows: [],
selectedRowKeys: [], // cleared so the selection resets after the modal closes
intermediateData: [],
intermediateId: '',
intermediateOtherKeys: []
};
}
......@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
.then(res => {
if (res.status === 200) {
const intermediateArr: number[] = [];
// support intermediate result is dict
// support dict-type intermediate results: the last intermediate result of a
// succeeded trial is its final result, and it may be a dict
// collect the keys of the intermediate result dict
let otherkeys: Array<string> = ['default'];
if (res.data.length !== 0) {
otherkeys = Object.keys(JSON.parse(res.data[0].data));
}
// intermediateArr stores only the default value
Object.keys(res.data).map(item => {
const temp = JSON.parse(res.data[item].data);
if (typeof temp === 'object') {
......@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
const intermediate = intermediateGraphOption(intermediateArr, id);
if (this._isMounted) {
this.setState(() => ({
intermediateOption: intermediate
intermediateData: res.data, // store the original intermediate data for this trial
intermediateOption: intermediate,
intermediateOtherKeys: otherkeys,
intermediateId: id
}));
}
}
......@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
}
}
selectOtherKeys = (value: string) => {
const isShowDefault: boolean = value === 'default';
const { intermediateData, intermediateId } = this.state;
const intermediateArr: number[] = [];
// the default key: the raw result may be a plain number or a dict
if (isShowDefault) {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
} else {
intermediateArr.push(temp);
}
});
} else {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
}
});
}
const intermediate = intermediateGraphOption(intermediateArr, intermediateId);
// re-render
if (this._isMounted) {
this.setState(() => ({
intermediateOption: intermediate
}));
}
}
hideIntermediateModal = () => {
if (this._isMounted) {
this.setState({
......@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
);
}
fillSelectedRowsTostate = (selected: number[] | string[], selectedRows: Array<TableObj>) => {
if (this._isMounted === true) {
this.setState(() => ({ selectRows: selectedRows, selectedRowKeys: selected }));
}
}
// open Compare-modal
compareBtn = () => {
const { selectRows } = this.state;
if (selectRows.length === 0) {
alert('Please select the trials you want to compare!');
} else {
if (this._isMounted === true) {
this.setState({ isShowCompareModal: true });
}
}
}
// close Compare-modal
hideCompareModal = () => {
// close the modal and clear the selected rows and selection state
if (this._isMounted) {
this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
}
}
componentDidMount() {
this._isMounted = true;
}
......@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() {
const { entries, tableSource, updateList } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state;
const { intermediateOption, modalVisible, isShowColumn, columnSelected,
selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys } = this.state;
const rowSelection = {
selectedRowKeys: selectedRowKeys,
onChange: (selected: string[] | number[], selectedRows: Array<TableObj>) => {
this.fillSelectedRowsTostate(selected, selectedRows);
}
};
let showTitle = COLUMN;
let bgColor = '';
const trialJob: Array<TrialJob> = [];
......@@ -417,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
<Row className="tableList">
<div id="tableList">
<Table
ref={(table: Table<TableObj> | null) => this.tables = table}
columns={showColumn}
rowSelection={rowSelection}
expandedRowRender={this.openRow}
dataSource={tableSource}
className="commonTableStyle"
......@@ -432,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
destroyOnClose={true}
width="80%"
>
{
intermediateOtherKeys.length > 1
?
<Row className="selectKeys">
<Select
className="select"
defaultValue="default"
onSelect={this.selectOtherKeys}
>
{
Object.keys(intermediateOtherKeys).map(item => {
const keys = intermediateOtherKeys[item];
return <Option value={keys} key={item}>{keys}</Option>;
})
}
</Select>
</Row>
:
<div />
}
<ReactEcharts
option={intermediateOption}
style={{
......@@ -458,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
className="titleColumn"
/>
</Modal>
<Compare compareRows={selectRows} visible={isShowCompareModal} cancelFunc={this.hideCompareModal} />
</Row>
);
}
......