Unverified Commit 611a45fc authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #19 from microsoft/master

pull code
parents 841d4677 e267a737
...@@ -26,7 +26,7 @@ import random ...@@ -26,7 +26,7 @@ import random
import numpy as np import numpy as np
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index, randint_to_quniform
import nni.parameter_expressions as parameter_expressions import nni.parameter_expressions as parameter_expressions
...@@ -175,6 +175,7 @@ class EvolutionTuner(Tuner): ...@@ -175,6 +175,7 @@ class EvolutionTuner(Tuner):
search_space : dict search_space : dict
""" """
self.searchspace_json = search_space self.searchspace_json = search_space
randint_to_quniform(self.searchspace_json)
self.space = json2space(self.searchspace_json) self.space = json2space(self.searchspace_json)
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
......
...@@ -31,7 +31,7 @@ import json_tricks ...@@ -31,7 +31,7 @@ import json_tricks
from nni.protocol import CommandType, send from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger from nni.common import init_logger
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, randint_to_quniform
import nni.parameter_expressions as parameter_expressions import nni.parameter_expressions as parameter_expressions
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -357,6 +357,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -357,6 +357,7 @@ class Hyperband(MsgDispatcherBase):
number of trial jobs number of trial jobs
""" """
self.searchspace_json = data self.searchspace_json = data
randint_to_quniform(self.searchspace_json)
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
def handle_trial_end(self, data): def handle_trial_end(self, data):
......
...@@ -27,7 +27,7 @@ import logging ...@@ -27,7 +27,7 @@ import logging
import hyperopt as hp import hyperopt as hp
import numpy as np import numpy as np
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index, randint_to_quniform
logger = logging.getLogger('hyperopt_AutoML') logger = logging.getLogger('hyperopt_AutoML')
...@@ -153,14 +153,14 @@ def _add_index(in_x, parameter): ...@@ -153,14 +153,14 @@ def _add_index(in_x, parameter):
Will change to format in hyperopt, like: Will change to format in hyperopt, like:
{'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}} {'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}}
""" """
if TYPE not in in_x: # if at the top level if NodeType.TYPE not in in_x: # if at the top level
out_y = dict() out_y = dict()
for key, value in parameter.items(): for key, value in parameter.items():
out_y[key] = _add_index(in_x[key], value) out_y[key] = _add_index(in_x[key], value)
return out_y return out_y
elif isinstance(in_x, dict): elif isinstance(in_x, dict):
value_type = in_x[TYPE] value_type = in_x[NodeType.TYPE]
value_format = in_x[VALUE] value_format = in_x[NodeType.VALUE]
if value_type == "choice": if value_type == "choice":
choice_name = parameter[0] if isinstance(parameter, choice_name = parameter[0] if isinstance(parameter,
list) else parameter list) else parameter
...@@ -173,15 +173,14 @@ def _add_index(in_x, parameter): ...@@ -173,15 +173,14 @@ def _add_index(in_x, parameter):
choice_value_format = item[1] choice_value_format = item[1]
if choice_key == choice_name: if choice_key == choice_name:
return { return {
INDEX: NodeType.INDEX: pos,
pos, NodeType.VALUE: [
VALUE: [
choice_name, choice_name,
_add_index(choice_value_format, parameter[1]) _add_index(choice_value_format, parameter[1])
] ]
} }
elif choice_name == item: elif choice_name == item:
return {INDEX: pos, VALUE: item} return {NodeType.INDEX: pos, NodeType.VALUE: item}
else: else:
return parameter return parameter
...@@ -232,6 +231,8 @@ class HyperoptTuner(Tuner): ...@@ -232,6 +231,8 @@ class HyperoptTuner(Tuner):
search_space : dict search_space : dict
""" """
self.json = search_space self.json = search_space
randint_to_quniform(self.json)
search_space_instance = json2space(self.json) search_space_instance = json2space(self.json)
rstate = np.random.RandomState() rstate = np.random.RandomState()
trials = hp.Trials() trials = hp.Trials()
......
...@@ -133,7 +133,7 @@ class MetisTuner(Tuner): ...@@ -133,7 +133,7 @@ class MetisTuner(Tuner):
self.x_bounds[idx] = bounds self.x_bounds[idx] = bounds
self.x_types[idx] = 'discrete_int' self.x_types[idx] = 'discrete_int'
elif key_type == 'randint': elif key_type == 'randint':
self.x_bounds[idx] = [0, key_range[0]] self.x_bounds[idx] = [key_range[0], key_range[1]]
self.x_types[idx] = 'range_int' self.x_types[idx] = 'range_int'
elif key_type == 'uniform': elif key_type == 'uniform':
self.x_bounds[idx] = [key_range[0], key_range[1]] self.x_bounds[idx] = [key_range[0], key_range[1]]
......
...@@ -21,21 +21,24 @@ ...@@ -21,21 +21,24 @@
smac_tuner.py smac_tuner.py
""" """
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
import sys import sys
import logging import logging
import numpy as np import numpy as np
import json_tricks
from enum import Enum, unique from nni.tuner import Tuner
from .convert_ss_to_scenario import generate_scenario from nni.utils import OptimizeMode, extract_scalar_reward
from smac.utils.io.cmd_reader import CMDReader from smac.utils.io.cmd_reader import CMDReader
from smac.scenario.scenario import Scenario from smac.scenario.scenario import Scenario
from smac.facade.smac_facade import SMAC from smac.facade.smac_facade import SMAC
from smac.facade.roar_facade import ROAR from smac.facade.roar_facade import ROAR
from smac.facade.epils_facade import EPILS from smac.facade.epils_facade import EPILS
from ConfigSpaceNNI import Configuration
from .convert_ss_to_scenario import generate_scenario
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, randint_to_quniform
class SMACTuner(Tuner): class SMACTuner(Tuner):
...@@ -57,6 +60,7 @@ class SMACTuner(Tuner): ...@@ -57,6 +60,7 @@ class SMACTuner(Tuner):
self.update_ss_done = False self.update_ss_done = False
self.loguniform_key = set() self.loguniform_key = set()
self.categorical_dict = {} self.categorical_dict = {}
self.cs = None
def _main_cli(self): def _main_cli(self):
"""Main function of SMAC for CLI interface """Main function of SMAC for CLI interface
...@@ -66,7 +70,7 @@ class SMACTuner(Tuner): ...@@ -66,7 +70,7 @@ class SMACTuner(Tuner):
instance instance
optimizer optimizer
""" """
self.logger.info("SMAC call: %s" % (" ".join(sys.argv))) self.logger.info("SMAC call: %s", " ".join(sys.argv))
cmd_reader = CMDReader() cmd_reader = CMDReader()
args, _ = cmd_reader.read_cmd() args, _ = cmd_reader.read_cmd()
...@@ -95,6 +99,7 @@ class SMACTuner(Tuner): ...@@ -95,6 +99,7 @@ class SMACTuner(Tuner):
# Create scenario-object # Create scenario-object
scen = Scenario(args.scenario_file, []) scen = Scenario(args.scenario_file, [])
self.cs = scen.cs
if args.mode == "SMAC": if args.mode == "SMAC":
optimizer = SMAC( optimizer = SMAC(
...@@ -134,6 +139,7 @@ class SMACTuner(Tuner): ...@@ -134,6 +139,7 @@ class SMACTuner(Tuner):
search_space: search_space:
search space search space
""" """
randint_to_quniform(search_space)
if not self.update_ss_done: if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space) self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None: if self.categorical_dict is None:
...@@ -258,4 +264,45 @@ class SMACTuner(Tuner): ...@@ -258,4 +264,45 @@ class SMACTuner(Tuner):
return params return params
def import_data(self, data): def import_data(self, data):
pass """Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
self.logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
self.logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
# convert the keys in loguniform and categorical types
valid_entry = True
for key, value in _params.items():
if key in self.loguniform_key:
_params[key] = np.log(value)
elif key in self.categorical_dict:
if value in self.categorical_dict[key]:
_params[key] = self.categorical_dict[key].index(value)
else:
self.logger.info("The value %s of key %s is not in search space.", str(value), key)
valid_entry = False
break
if not valid_entry:
continue
# start import this data entry
_completed_num += 1
config = Configuration(self.cs, values=_params)
if self.optimize_mode is OptimizeMode.Maximize:
_value = -_value
if self.first_one:
self.smbo_solver.nni_smac_receive_first_run(config, _value)
self.first_one = False
else:
self.smbo_solver.nni_smac_receive_runs(config, _value)
self.logger.info("Successfully import data to smac tuner, total data: %d, imported data: %d.", len(data), _completed_num)
...@@ -36,7 +36,8 @@ __all__ = [ ...@@ -36,7 +36,8 @@ __all__ = [
'qnormal', 'qnormal',
'lognormal', 'lognormal',
'qlognormal', 'qlognormal',
'function_choice' 'function_choice',
'mutable_layer'
] ]
...@@ -78,6 +79,9 @@ if trial_env_vars.NNI_PLATFORM is None: ...@@ -78,6 +79,9 @@ if trial_env_vars.NNI_PLATFORM is None:
def function_choice(*funcs, name=None): def function_choice(*funcs, name=None):
return random.choice(funcs)() return random.choice(funcs)()
def mutable_layer():
raise RuntimeError('Cannot call nni.mutable_layer in this mode')
else: else:
def choice(options, name=None, key=None): def choice(options, name=None, key=None):
...@@ -113,6 +117,42 @@ else: ...@@ -113,6 +117,42 @@ else:
def function_choice(funcs, name=None, key=None): def function_choice(funcs, name=None, key=None):
return funcs[_get_param(key)]() return funcs[_get_param(key)]()
def mutable_layer(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size=0):
'''execute the chosen function and inputs.
Below is an example of chosen function and inputs:
{
"mutable_id": {
"mutable_layer_id": {
"chosen_layer": "pool",
"chosen_inputs": ["out1", "out3"]
}
}
}
Parameters:
---------------
mutable_id: the name of this mutable_layer block (which could have multiple mutable layers)
mutable_layer_id: the name of a mutable layer in this block
funcs: dict of function calls
funcs_args:
fixed_inputs:
optional_inputs: dict of optional inputs
optional_input_size: number of candidate inputs to be chosen
'''
mutable_block = _get_param(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
real_chosen_inputs = [optional_inputs[input_name] for input_name in chosen_inputs]
layer_out = funcs[chosen_layer]([fixed_inputs, real_chosen_inputs], *funcs_args[chosen_layer])
return layer_out
def _get_param(key): def _get_param(key):
if trial._params is None: if trial._params is None:
trial.get_next_parameter() trial.get_next_parameter()
......
...@@ -40,6 +40,7 @@ class OptimizeMode(Enum): ...@@ -40,6 +40,7 @@ class OptimizeMode(Enum):
Minimize = 'minimize' Minimize = 'minimize'
Maximize = 'maximize' Maximize = 'maximize'
class NodeType: class NodeType:
"""Node Type class """Node Type class
""" """
...@@ -83,6 +84,7 @@ def extract_scalar_reward(value, scalar_key='default'): ...@@ -83,6 +84,7 @@ def extract_scalar_reward(value, scalar_key='default'):
raise RuntimeError('Incorrect final result: the final result should be float/int, or a dict which has a key named "default" whose value is float/int.') raise RuntimeError('Incorrect final result: the final result should be float/int, or a dict which has a key named "default" whose value is float/int.')
return reward return reward
def convert_dict2tuple(value): def convert_dict2tuple(value):
""" """
convert dict type to tuple to solve unhashable problem. convert dict type to tuple to solve unhashable problem.
...@@ -94,9 +96,30 @@ def convert_dict2tuple(value): ...@@ -94,9 +96,30 @@ def convert_dict2tuple(value):
else: else:
return value return value
def init_dispatcher_logger(): def init_dispatcher_logger():
""" Initialize dispatcher logging configuration""" """ Initialize dispatcher logging configuration"""
logger_file_path = 'dispatcher.log' logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None: if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path) logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL) init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL)
def randint_to_quniform(in_x):
if isinstance(in_x, dict):
if NodeType.TYPE in in_x.keys():
if in_x[NodeType.TYPE] == 'randint':
value = in_x[NodeType.VALUE]
value.append(1)
in_x[NodeType.TYPE] = 'quniform'
in_x[NodeType.VALUE] = value
elif in_x[NodeType.TYPE] == 'choice':
randint_to_quniform(in_x[NodeType.VALUE])
else:
for key in in_x.keys():
randint_to_quniform(in_x[key])
elif isinstance(in_x, list):
for temp in in_x:
randint_to_quniform(temp)
...@@ -192,17 +192,21 @@ class Overview extends React.Component<{}, OverviewState> { ...@@ -192,17 +192,21 @@ class Overview extends React.Component<{}, OverviewState> {
method: 'GET' method: 'GET'
}) })
.then(res => { .then(res => {
if (res.status === 200 && this._isMounted) { if (res.status === 200) {
const errors = res.data.errors; const errors = res.data.errors;
if (errors.length !== 0) { if (errors.length !== 0) {
this.setState({ if (this._isMounted) {
status: res.data.status, this.setState({
errorStr: res.data.errors[0] status: res.data.status,
}); errorStr: res.data.errors[0]
});
}
} else { } else {
this.setState({ if (this._isMounted) {
status: res.data.status, this.setState({
}); status: res.data.status,
});
}
} }
} }
}); });
...@@ -254,7 +258,8 @@ class Overview extends React.Component<{}, OverviewState> { ...@@ -254,7 +258,8 @@ class Overview extends React.Component<{}, OverviewState> {
case 'SUCCEEDED': case 'SUCCEEDED':
profile.succTrial += 1; profile.succTrial += 1;
const desJobDetail: Parameters = { const desJobDetail: Parameters = {
parameters: {} parameters: {},
intermediate: []
}; };
const duration = (tableData[item].endTime - tableData[item].startTime) / 1000; const duration = (tableData[item].endTime - tableData[item].startTime) / 1000;
const acc = getFinal(tableData[item].finalMetricData); const acc = getFinal(tableData[item].finalMetricData);
......
...@@ -27,6 +27,11 @@ interface TrialDetailState { ...@@ -27,6 +27,11 @@ interface TrialDetailState {
entriesInSelect: string; entriesInSelect: string;
searchSpace: string; searchSpace: string;
isMultiPhase: boolean; isMultiPhase: boolean;
isTableLoading: boolean;
whichGraph: string;
hyperCounts: number; // user click the hyper-parameter counts
durationCounts: number;
intermediateCounts: number;
} }
class TrialsDetail extends React.Component<{}, TrialDetailState> { class TrialsDetail extends React.Component<{}, TrialDetailState> {
...@@ -70,9 +75,14 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -70,9 +75,14 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
experimentLogCollection: false, experimentLogCollection: false,
entriesTable: 20, entriesTable: 20,
entriesInSelect: '20', entriesInSelect: '20',
isHasSearch: false,
searchSpace: '', searchSpace: '',
isMultiPhase: false whichGraph: '1',
isHasSearch: false,
isMultiPhase: false,
isTableLoading: false,
hyperCounts: 0,
durationCounts: 0,
intermediateCounts: 0
}; };
} }
...@@ -85,6 +95,9 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -85,6 +95,9 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
]) ])
.then(axios.spread((res, res1) => { .then(axios.spread((res, res1) => {
if (res.status === 200 && res1.status === 200) { if (res.status === 200 && res1.status === 200) {
if (this._isMounted === true) {
this.setState(() => ({ isTableLoading: true }));
}
const trialJobs = res.data; const trialJobs = res.data;
const metricSource = res1.data; const metricSource = res1.data;
const trialTable: Array<TableObj> = []; const trialTable: Array<TableObj> = [];
...@@ -175,6 +188,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -175,6 +188,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
} }
if (this._isMounted) { if (this._isMounted) {
this.setState(() => ({ this.setState(() => ({
isTableLoading: false,
tableListSource: trialTable tableListSource: trialTable
})); }));
} }
...@@ -239,26 +253,26 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -239,26 +253,26 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
} }
handleEntriesSelect = (value: string) => { handleEntriesSelect = (value: string) => {
switch (value) { // user select isn't 'all'
case '20': if (value !== 'all') {
this.setState(() => ({ entriesTable: 20 })); if (this._isMounted) {
break; this.setState(() => ({ entriesTable: parseInt(value, 10) }));
case '50': }
this.setState(() => ({ entriesTable: 50 })); } else {
break; const { tableListSource } = this.state;
case '100': if (this._isMounted) {
this.setState(() => ({ entriesTable: 100 })); this.setState(() => ({
break; entriesInSelect: 'all',
case 'all': entriesTable: tableListSource.length
const { tableListSource } = this.state; }));
if (this._isMounted) { }
this.setState(() => ({ }
entriesInSelect: 'all', }
entriesTable: tableListSource.length
})); handleWhichTabs = (activeKey: string) => {
} // const which = JSON.parse(activeKey);
break; if (this._isMounted) {
default: this.setState(() => ({ whichGraph: activeKey }));
} }
} }
...@@ -315,18 +329,21 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -315,18 +329,21 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
const { const {
tableListSource, searchResultSource, isHasSearch, isMultiPhase, tableListSource, searchResultSource, isHasSearch, isMultiPhase,
entriesTable, experimentPlatform, searchSpace, experimentLogCollection entriesTable, experimentPlatform, searchSpace, experimentLogCollection,
whichGraph, isTableLoading
} = this.state; } = this.state;
const source = isHasSearch ? searchResultSource : tableListSource; const source = isHasSearch ? searchResultSource : tableListSource;
return ( return (
<div> <div>
<div className="trial" id="tabsty"> <div className="trial" id="tabsty">
<Tabs type="card"> <Tabs type="card" onChange={this.handleWhichTabs}>
{/* <TabPane tab={this.titleOfacc} key="1" destroyInactiveTabPane={true}> */}
<TabPane tab={this.titleOfacc} key="1"> <TabPane tab={this.titleOfacc} key="1">
<Row className="graph"> <Row className="graph">
<DefaultPoint <DefaultPoint
height={432} height={432}
showSource={source} showSource={source}
whichGraph={whichGraph}
/> />
</Row> </Row>
</TabPane> </TabPane>
...@@ -335,14 +352,16 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -335,14 +352,16 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
<Para <Para
dataSource={source} dataSource={source}
expSearchSpace={searchSpace} expSearchSpace={searchSpace}
whichGraph={whichGraph}
/> />
</Row> </Row>
</TabPane> </TabPane>
<TabPane tab={this.titleOfDuration} key="3"> <TabPane tab={this.titleOfDuration} key="3">
<Duration source={source} /> <Duration source={source} whichGraph={whichGraph} />
{/* <Duration source={source} whichGraph={whichGraph} clickCounts={durationCounts} /> */}
</TabPane> </TabPane>
<TabPane tab={this.titleOfIntermediate} key="4"> <TabPane tab={this.titleOfIntermediate} key="4">
<Intermediate source={source} /> <Intermediate source={source} whichGraph={whichGraph} />
</TabPane> </TabPane>
</Tabs> </Tabs>
</div> </div>
...@@ -388,6 +407,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -388,6 +407,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
<TableList <TableList
entries={entriesTable} entries={entriesTable}
tableSource={source} tableSource={source}
isTableLoading={isTableLoading}
isMultiPhase={isMultiPhase} isMultiPhase={isMultiPhase}
platform={experimentPlatform} platform={experimentPlatform}
updateList={this.getDetailSource} updateList={this.getDetailSource}
......
...@@ -28,12 +28,11 @@ class IntermediateVal extends React.Component<IntermediateValProps, {}> { ...@@ -28,12 +28,11 @@ class IntermediateVal extends React.Component<IntermediateValProps, {}> {
if (wei > 6) { if (wei > 6) {
result = `${lastVal.toFixed(6)}`; result = `${lastVal.toFixed(6)}`;
} }
if (status === 'SUCCEEDED') { }
result = `${lastVal.toFixed(6)} (FINAL)`; if (status === 'SUCCEEDED') {
} else { result = `${result} (FINAL)`;
result = `${lastVal.toFixed(6)} (LATEST)`; } else {
} result = `${result} (LATEST)`;
} }
} else { } else {
result = '--'; result = '--';
......
...@@ -3,9 +3,10 @@ import * as copy from 'copy-to-clipboard'; ...@@ -3,9 +3,10 @@ import * as copy from 'copy-to-clipboard';
import PaiTrialLog from '../public-child/PaiTrialLog'; import PaiTrialLog from '../public-child/PaiTrialLog';
import TrialLog from '../public-child/TrialLog'; import TrialLog from '../public-child/TrialLog';
import { TableObj } from '../../static/interface'; import { TableObj } from '../../static/interface';
import { Row, Tabs, Button, message } from 'antd'; import { Row, Tabs, Button, message, Modal } from 'antd';
import { MANAGER_IP } from '../../static/const'; import { MANAGER_IP } from '../../static/const';
import '../../static/style/overview.scss'; import '../../static/style/overview.scss';
import '../../static/style/copyParameter.scss';
import JSONTree from 'react-json-tree'; import JSONTree from 'react-json-tree';
const TabPane = Tabs.TabPane; const TabPane = Tabs.TabPane;
...@@ -17,43 +18,62 @@ interface OpenRowProps { ...@@ -17,43 +18,62 @@ interface OpenRowProps {
} }
interface OpenRowState { interface OpenRowState {
idList: Array<string>; isShowFormatModal: boolean;
formatStr: string;
} }
class OpenRow extends React.Component<OpenRowProps, OpenRowState> { class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
public _isMounted: boolean;
constructor(props: OpenRowProps) { constructor(props: OpenRowProps) {
super(props); super(props);
this.state = { this.state = {
idList: [''] isShowFormatModal: false,
formatStr: ''
}; };
}
showFormatModal = (record: TableObj) => {
// get copy parameters
const params = JSON.stringify(record.description.parameters, null, 4);
// open modal with format string
if (this._isMounted === true) {
this.setState(() => ({ isShowFormatModal: true, formatStr: params }));
}
}
hideFormatModal = () => {
// close modal, destroy state format string data
if (this._isMounted === true) {
this.setState(() => ({ isShowFormatModal: false, formatStr: '' }));
}
} }
copyParams = (record: TableObj) => { copyParams = () => {
// json format // json format
const params = JSON.stringify(record.description.parameters, null, 4); const { formatStr } = this.state;
if (copy(params)) { if (copy(formatStr)) {
message.destroy(); message.destroy();
message.success('Success copy parameters to clipboard in form of python dict !', 3); message.success('Success copy parameters to clipboard in form of python dict !', 3);
const { idList } = this.state;
const copyIdList: Array<string> = idList;
copyIdList[copyIdList.length - 1] = record.id;
this.setState(() => ({
idList: copyIdList
}));
} else { } else {
message.destroy(); message.destroy();
message.error('Failed !', 2); message.error('Failed !', 2);
} }
this.hideFormatModal();
} }
componentDidMount() {
this._isMounted = true;
}
componentWillUnmount() {
this._isMounted = false;
}
render() { render() {
const { trainingPlatform, record, logCollection, multiphase } = this.props; const { trainingPlatform, record, logCollection, multiphase } = this.props;
const { idList } = this.state; const { isShowFormatModal, formatStr } = this.state;
let isClick = false; let isClick = false;
let isHasParameters = true; let isHasParameters = true;
if (idList.indexOf(record.id) !== -1) { isClick = true; }
if (record.description.parameters.error) { if (record.description.parameters.error) {
isHasParameters = false; isHasParameters = false;
} }
...@@ -101,7 +121,7 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> { ...@@ -101,7 +121,7 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
</Row> </Row>
<Row className="copy"> <Row className="copy">
<Button <Button
onClick={this.copyParams.bind(this, record)} onClick={this.showFormatModal.bind(this, record)}
> >
Copy as python Copy as python
</Button> </Button>
...@@ -128,6 +148,21 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> { ...@@ -128,6 +148,21 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
} }
</TabPane> </TabPane>
</Tabs> </Tabs>
<Modal
title="Format"
okText="Copy"
centered={true}
visible={isShowFormatModal}
onCancel={this.hideFormatModal}
maskClosable={false} // click mongolian layer don't close modal
onOk={this.copyParams}
destroyOnClose={true}
width="60%"
className="format"
>
{/* write string in pre to show format string */}
<pre className="formatStr">{formatStr}</pre>
</Modal>
</Row > </Row >
); );
} }
......
import * as React from 'react'; import * as React from 'react';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface'; import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface';
require('echarts/lib/chart/scatter'); require('echarts/lib/chart/scatter');
require('echarts/lib/component/tooltip'); require('echarts/lib/component/tooltip');
...@@ -8,11 +9,13 @@ require('echarts/lib/component/title'); ...@@ -8,11 +9,13 @@ require('echarts/lib/component/title');
interface DefaultPointProps { interface DefaultPointProps {
showSource: Array<TableObj>; showSource: Array<TableObj>;
height: number; height: number;
whichGraph: string;
} }
interface DefaultPointState { interface DefaultPointState {
defaultSource: object; defaultSource: object;
accNodata: string; accNodata: string;
succeedTrials: number;
} }
class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> { class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
...@@ -22,91 +25,130 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> ...@@ -22,91 +25,130 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
super(props); super(props);
this.state = { this.state = {
defaultSource: {}, defaultSource: {},
accNodata: 'No data' accNodata: '',
succeedTrials: 10000000
}; };
} }
defaultMetric = (showSource: Array<TableObj>) => { defaultMetric = (succeedSource: Array<TableObj>) => {
const accSource: Array<DetailAccurPoint> = []; const accSource: Array<DetailAccurPoint> = [];
Object.keys(showSource).map(item => { const showSource: Array<TableObj> = succeedSource.filter(filterByStatus);
const temp = showSource[item]; const lengthOfSource = showSource.length;
if (temp.status === 'SUCCEEDED' && temp.acc !== undefined) { const tooltipDefault = lengthOfSource === 0 ? 'No data' : '';
if (temp.acc.default !== undefined) { if (this._isMounted === true) {
const searchSpace = temp.description.parameters; this.setState(() => ({
accSource.push({ succeedTrials: lengthOfSource,
acc: temp.acc.default, accNodata: tooltipDefault
index: temp.sequenceId, }));
searchSpace: JSON.stringify(searchSpace) }
}); if (lengthOfSource === 0) {
const nullGraph = {
grid: {
left: '8%'
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
} }
};
if (this._isMounted === true) {
this.setState(() => ({
defaultSource: nullGraph
}));
} }
}); } else {
const resultList: Array<number | string>[] = []; const resultList: Array<number | string>[] = [];
Object.keys(accSource).map(item => { Object.keys(showSource).map(item => {
const items = accSource[item]; const temp = showSource[item];
let temp: Array<number | string>; if (temp.acc !== undefined) {
temp = [items.index, items.acc, JSON.parse(items.searchSpace)]; if (temp.acc.default !== undefined) {
resultList.push(temp); const searchSpace = temp.description.parameters;
}); accSource.push({
acc: temp.acc.default,
const allAcuracy = { index: temp.sequenceId,
grid: { searchSpace: JSON.stringify(searchSpace)
left: '8%' });
},
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForAccuracy) {
if (data.data[0] < resultList.length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
} }
},
formatter: function (data: TooltipForAccuracy) {
const result = '<div class="tooldetailAccuracy">' +
'<div>Trial No.: ' + data.data[0] + '</div>' +
'<div>Default metric: ' + data.data[1] + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(data.data[2], null, 4) + '</pre>' +
'</div>' +
'</div>';
return result;
}
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
},
series: [{
symbolSize: 6,
type: 'scatter',
data: resultList
}]
};
if (this._isMounted === true) {
this.setState({ defaultSource: allAcuracy }, () => {
if (resultList.length === 0) {
this.setState({
accNodata: 'No data'
});
} else {
this.setState({
accNodata: ''
});
} }
}); });
Object.keys(accSource).map(item => {
const items = accSource[item];
let temp: Array<number | string>;
temp = [items.index, items.acc, JSON.parse(items.searchSpace)];
resultList.push(temp);
});
const allAcuracy = {
grid: {
left: '8%'
},
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForAccuracy) {
if (data.data[0] < resultList.length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForAccuracy) {
const result = '<div class="tooldetailAccuracy">' +
'<div>Trial No.: ' + data.data[0] + '</div>' +
'<div>Default metric: ' + data.data[1] + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(data.data[2], null, 4) + '</pre>' +
'</div>' +
'</div>';
return result;
}
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
},
series: [{
symbolSize: 6,
type: 'scatter',
data: resultList
}]
};
if (this._isMounted === true) {
this.setState(() => ({
defaultSource: allAcuracy
}));
}
} }
} }
// update parent component state // update parent component state
componentWillReceiveProps(nextProps: DefaultPointProps) { componentWillReceiveProps(nextProps: DefaultPointProps) {
const showSource = nextProps.showSource;
this.defaultMetric(showSource); const { whichGraph, showSource } = nextProps;
if (whichGraph === '1') {
this.defaultMetric(showSource);
}
}
shouldComponentUpdate(nextProps: DefaultPointProps, nextState: DefaultPointState) {
const { whichGraph } = nextProps;
const succTrial = this.state.succeedTrials;
const { succeedTrials } = nextState;
if (whichGraph === '1') {
if (succeedTrials !== succTrial) {
return true;
}
}
// only whichGraph !== '1', default metric can't update
return false;
} }
componentDidMount() { componentDidMount() {
...@@ -116,7 +158,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> ...@@ -116,7 +158,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
componentWillUnmount() { componentWillUnmount() {
this._isMounted = false; this._isMounted = false;
} }
render() { render() {
const { height } = this.props; const { height } = this.props;
const { defaultSource, accNodata } = this.state; const { defaultSource, accNodata } = this.state;
...@@ -131,6 +173,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> ...@@ -131,6 +173,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}} }}
theme="my_theme" theme="my_theme"
notMerge={true} // update now notMerge={true} // update now
// lazyUpdate={true}
/> />
<div className="showMess">{accNodata}</div> <div className="showMess">{accNodata}</div>
</div> </div>
......
import * as React from 'react'; import * as React from 'react';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { TableObj } from 'src/static/interface'; import { TableObj } from 'src/static/interface';
import { filterDuration } from 'src/static/function';
require('echarts/lib/chart/bar'); require('echarts/lib/chart/bar');
require('echarts/lib/component/tooltip'); require('echarts/lib/component/tooltip');
require('echarts/lib/component/title'); require('echarts/lib/component/title');
...@@ -12,6 +13,7 @@ interface Runtrial { ...@@ -12,6 +13,7 @@ interface Runtrial {
interface DurationProps { interface DurationProps {
source: Array<TableObj>; source: Array<TableObj>;
whichGraph: string;
} }
interface DurationState { interface DurationState {
...@@ -26,13 +28,64 @@ class Duration extends React.Component<DurationProps, DurationState> { ...@@ -26,13 +28,64 @@ class Duration extends React.Component<DurationProps, DurationState> {
super(props); super(props);
this.state = { this.state = {
durationSource: {} durationSource: this.initDuration(this.props.source),
}; };
} }
initDuration = (source: Array<TableObj>) => {
const trialId: Array<string> = [];
const trialTime: Array<number> = [];
const trialJobs = source.filter(filterDuration);
Object.keys(trialJobs).map(item => {
const temp = trialJobs[item];
trialId.push(temp.sequenceId);
trialTime.push(temp.duration);
});
return {
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'shadow'
}
},
grid: {
bottom: '3%',
containLabel: true,
left: '1%',
right: '4%'
},
dataZoom: [{
type: 'slider',
name: 'trial',
filterMode: 'filter',
yAxisIndex: 0,
orient: 'vertical'
}, {
type: 'slider',
name: 'trial',
filterMode: 'filter',
xAxisIndex: 0
}],
xAxis: {
name: 'Time',
type: 'value',
},
yAxis: {
name: 'Trial',
type: 'category',
data: trialId
},
series: [{
type: 'bar',
data: trialTime
}]
};
}
getOption = (dataObj: Runtrial) => { getOption = (dataObj: Runtrial) => {
return { return {
tooltip: { tooltip: {
trigger: 'axis', trigger: 'axis',
axisPointer: { axisPointer: {
...@@ -45,7 +98,7 @@ class Duration extends React.Component<DurationProps, DurationState> { ...@@ -45,7 +98,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
left: '1%', left: '1%',
right: '4%' right: '4%'
}, },
dataZoom: [{ dataZoom: [{
type: 'slider', type: 'slider',
name: 'trial', name: 'trial',
...@@ -74,17 +127,16 @@ class Duration extends React.Component<DurationProps, DurationState> { ...@@ -74,17 +127,16 @@ class Duration extends React.Component<DurationProps, DurationState> {
}; };
} }
drawDurationGraph = (trialJobs: Array<TableObj>) => { drawDurationGraph = (source: Array<TableObj>) => {
// why this function run two times when props changed?
const trialId: Array<string> = []; const trialId: Array<string> = [];
const trialTime: Array<number> = []; const trialTime: Array<number> = [];
const trialRun: Array<Runtrial> = []; const trialRun: Array<Runtrial> = [];
const trialJobs = source.filter(filterDuration);
Object.keys(trialJobs).map(item => { Object.keys(trialJobs).map(item => {
const temp = trialJobs[item]; const temp = trialJobs[item];
if (temp.status !== 'WAITING') { trialId.push(temp.sequenceId);
trialId.push(temp.sequenceId); trialTime.push(temp.duration);
trialTime.push(temp.duration);
}
}); });
trialRun.push({ trialRun.push({
trialId: trialId, trialId: trialId,
...@@ -97,18 +149,43 @@ class Duration extends React.Component<DurationProps, DurationState> { ...@@ -97,18 +149,43 @@ class Duration extends React.Component<DurationProps, DurationState> {
} }
} }
componentWillReceiveProps(nextProps: DurationProps) {
const trialJobs = nextProps.source;
this.drawDurationGraph(trialJobs);
}
componentDidMount() { componentDidMount() {
this._isMounted = true; this._isMounted = true;
// init: user don't search const { source } = this.props;
const {source} = this.props;
this.drawDurationGraph(source); this.drawDurationGraph(source);
} }
componentWillReceiveProps(nextProps: DurationProps) {
const { whichGraph, source } = nextProps;
if (whichGraph === '3') {
this.drawDurationGraph(source);
}
}
shouldComponentUpdate(nextProps: DurationProps, nextState: DurationState) {
const { whichGraph, source } = nextProps;
if (whichGraph === '3') {
const beforeSource = this.props.source;
if (whichGraph !== this.props.whichGraph) {
return true;
}
if (source.length !== beforeSource.length) {
return true;
}
if (source[source.length - 1].duration !== beforeSource[beforeSource.length - 1].duration) {
return true;
}
if (source[source.length - 1].status !== beforeSource[beforeSource.length - 1].status) {
return true;
}
}
return false;
}
componentWillUnmount() { componentWillUnmount() {
this._isMounted = false; this._isMounted = false;
} }
...@@ -121,6 +198,7 @@ class Duration extends React.Component<DurationProps, DurationState> { ...@@ -121,6 +198,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
option={durationSource} option={durationSource}
style={{ width: '95%', height: 412, margin: '0 auto' }} style={{ width: '95%', height: 412, margin: '0 auto' }}
theme="my_theme" theme="my_theme"
notMerge={true} // update now
/> />
</div> </div>
); );
......
...@@ -11,16 +11,21 @@ interface Intermedia { ...@@ -11,16 +11,21 @@ interface Intermedia {
data: Array<number | object>; // intermediate data data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value hyperPara: object; // each trial hyperpara value
} }
interface IntermediateState { interface IntermediateState {
detailSource: Array<TableObj>;
interSource: object; interSource: object;
filterSource: Array<TableObj>; filterSource: Array<TableObj>;
eachIntermediateNum: number; // trial's intermediate number count eachIntermediateNum: number; // trial's intermediate number count
isLoadconfirmBtn: boolean; isLoadconfirmBtn: boolean;
isFilter: boolean; isFilter: boolean;
length: number;
clickCounts: number; // user filter intermediate click confirm btn's counts
} }
interface IntermediateProps { interface IntermediateProps {
source: Array<TableObj>; source: Array<TableObj>;
whichGraph: string;
} }
class Intermediate extends React.Component<IntermediateProps, IntermediateState> { class Intermediate extends React.Component<IntermediateProps, IntermediateState> {
...@@ -34,39 +39,25 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState> ...@@ -34,39 +39,25 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
constructor(props: IntermediateProps) { constructor(props: IntermediateProps) {
super(props); super(props);
this.state = { this.state = {
detailSource: [],
interSource: {}, interSource: {},
filterSource: [], filterSource: [],
eachIntermediateNum: 1, eachIntermediateNum: 1,
isLoadconfirmBtn: false, isLoadconfirmBtn: false,
isFilter: false isFilter: false,
}; length: 100000,
} clickCounts: 0
initMediate = () => {
const option = {
grid: {
left: '5%',
top: 40,
containLabel: true
},
xAxis: {
type: 'category',
boundaryGap: false,
},
yAxis: {
type: 'value',
name: 'Scape'
}
}; };
if (this._isMounted) {
this.setState(() => ({
interSource: option
}));
}
} }
drawIntermediate = (source: Array<TableObj>) => { drawIntermediate = (source: Array<TableObj>) => {
if (source.length > 0) { if (source.length > 0) {
if (this._isMounted) {
this.setState(() => ({
length: source.length,
detailSource: source
}));
}
const trialIntermediate: Array<Intermedia> = []; const trialIntermediate: Array<Intermedia> = [];
Object.keys(source).map(item => { Object.keys(source).map(item => {
const temp = source[item]; const temp = source[item];
...@@ -140,7 +131,24 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState> ...@@ -140,7 +131,24 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
})); }));
} }
} else { } else {
this.initMediate(); const nullData = {
grid: {
left: '5%',
top: 40,
containLabel: true
},
xAxis: {
type: 'category',
boundaryGap: false,
},
yAxis: {
type: 'value',
name: 'Scape'
}
};
if (this._isMounted) {
this.setState(() => ({ interSource: nullData }));
}
} }
} }
...@@ -183,8 +191,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState> ...@@ -183,8 +191,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
this.setState({ filterSource: filterSource }); this.setState({ filterSource: filterSource });
} }
this.drawIntermediate(filterSource); this.drawIntermediate(filterSource);
const counts = this.state.clickCounts + 1;
this.setState({ isLoadconfirmBtn: false, clickCounts: counts });
} }
this.setState({ isLoadconfirmBtn: false });
}); });
} }
} }
...@@ -204,28 +213,73 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState> ...@@ -204,28 +213,73 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
this.drawIntermediate(source); this.drawIntermediate(source);
} }
componentWillReceiveProps(nextProps: IntermediateProps) { componentWillReceiveProps(nextProps: IntermediateProps, nextState: IntermediateState) {
const { isFilter, filterSource } = this.state; const { isFilter, filterSource } = nextState;
if (isFilter === true) { const { whichGraph, source } = nextProps;
const pointVal = this.pointInput !== null ? this.pointInput.value : '';
const minVal = this.minValInput !== null ? this.minValInput.value : ''; if (whichGraph === '4') {
if (pointVal === '' && minVal === '') { if (isFilter === true) {
this.drawIntermediate(nextProps.source); const pointVal = this.pointInput !== null ? this.pointInput.value : '';
const minVal = this.minValInput !== null ? this.minValInput.value : '';
if (pointVal === '' && minVal === '') {
this.drawIntermediate(source);
} else {
this.drawIntermediate(filterSource);
}
} else { } else {
this.drawIntermediate(filterSource); this.drawIntermediate(source);
} }
} else {
this.drawIntermediate(nextProps.source);
} }
} }
shouldComponentUpdate(nextProps: IntermediateProps, nextState: IntermediateState) {
const { whichGraph } = nextProps;
const beforeGraph = this.props.whichGraph;
if (whichGraph === '4') {
const { source } = nextProps;
const { isFilter, length, clickCounts } = nextState;
const beforeLength = this.state.length;
const beforeSource = this.state.detailSource;
const beforeClickCounts = this.state.clickCounts;
if (isFilter !== this.state.isFilter) {
return true;
}
if (clickCounts !== beforeClickCounts) {
return true;
}
if (isFilter === false) {
if (whichGraph !== beforeGraph) {
return true;
}
if (length !== beforeLength) {
return true;
}
if (source[source.length - 1].description.intermediate.length !==
beforeSource[beforeSource.length - 1].description.intermediate.length) {
return true;
}
if (source[source.length - 1].duration !== beforeSource[beforeSource.length - 1].duration) {
return true;
}
if (source[source.length - 1].status !== beforeSource[beforeSource.length - 1].status) {
return true;
}
}
}
return false;
}
componentWillUnmount() { componentWillUnmount() {
this._isMounted = false; this._isMounted = false;
} }
render() { render() {
const { interSource, isLoadconfirmBtn, isFilter } = this.state; const { interSource, isLoadconfirmBtn, isFilter } = this.state;
return ( return (
<div> <div>
{/* style in para.scss */} {/* style in para.scss */}
......
import * as React from 'react'; import * as React from 'react';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { Row, Col, Select, Button, message } from 'antd'; import { Row, Col, Select, Button, message } from 'antd';
import { ParaObj, Dimobj, TableObj, SearchSpace } from '../../static/interface'; import { ParaObj, Dimobj, TableObj } from '../../static/interface';
const Option = Select.Option; const Option = Select.Option;
require('echarts/lib/chart/parallel'); require('echarts/lib/chart/parallel');
require('echarts/lib/component/tooltip'); require('echarts/lib/component/tooltip');
...@@ -11,6 +12,7 @@ require('../../static/style/para.scss'); ...@@ -11,6 +12,7 @@ require('../../static/style/para.scss');
require('../../static/style/button.scss'); require('../../static/style/button.scss');
interface ParaState { interface ParaState {
// paraSource: Array<TableObj>;
option: object; option: object;
paraBack: ParaObj; paraBack: ParaObj;
dimName: Array<string>; dimName: Array<string>;
...@@ -19,11 +21,15 @@ interface ParaState { ...@@ -19,11 +21,15 @@ interface ParaState {
paraNodata: string; paraNodata: string;
max: number; // graph color bar limit max: number; // graph color bar limit
min: number; min: number;
sutrialCount: number; // succeed trial numbers for SUC
clickCounts: number;
isLoadConfirm: boolean;
} }
interface ParaProps { interface ParaProps {
dataSource: Array<TableObj>; dataSource: Array<TableObj>;
expSearchSpace: string; expSearchSpace: string;
whichGraph: string;
} }
message.config({ message.config({
...@@ -45,6 +51,8 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -45,6 +51,8 @@ class Para extends React.Component<ParaProps, ParaState> {
constructor(props: ParaProps) { constructor(props: ParaProps) {
super(props); super(props);
this.state = { this.state = {
// paraSource: [],
// option: this.hyperParaPic,
option: {}, option: {},
dimName: [], dimName: [],
paraBack: { paraBack: {
...@@ -58,98 +66,20 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -58,98 +66,20 @@ class Para extends React.Component<ParaProps, ParaState> {
percent: 0, percent: 0,
paraNodata: '', paraNodata: '',
min: 0, min: 0,
max: 1 max: 1,
sutrialCount: 10000000,
clickCounts: 1,
isLoadConfirm: false
}; };
} }
componentDidMount() {
this._isMounted = true;
this.reInit();
}
getParallelAxis = getParallelAxis =
( (
dimName: Array<string>, searchRange: SearchSpace, dimName: Array<string>, parallelAxis: Array<Dimobj>,
accPara: Array<number>, accPara: Array<number>, eachTrialParams: Array<string>
eachTrialParams: Array<string>, paraYdata: number[][]
) => { ) => {
if (this._isMounted) {
this.setState(() => ({
dimName: dimName
}));
}
const parallelAxis: Array<Dimobj> = [];
// search space range and specific value [only number]
for (let i = 0; i < dimName.length; i++) {
const searchKey = searchRange[dimName[i]];
switch (searchKey._type) {
case 'uniform':
case 'quniform':
parallelAxis.push({
dim: i,
name: dimName[i],
max: searchKey._value[1],
min: searchKey._value[0]
});
break;
case 'randint':
parallelAxis.push({
dim: i,
name: dimName[i],
max: searchKey._value[0] - 1,
min: 0
});
break;
case 'choice':
const data: Array<string> = [];
for (let j = 0; j < searchKey._value.length; j++) {
data.push(searchKey._value[j].toString());
}
parallelAxis.push({
dim: i,
name: dimName[i],
type: 'category',
data: data,
boundaryGap: true,
axisLine: {
lineStyle: {
type: 'dotted', // axis type,solid,dashed,dotted
width: 1
}
},
axisTick: {
show: true,
interval: 0,
alignWithLabel: true,
},
axisLabel: {
show: true,
interval: 0,
// rotate: 30
},
});
break;
// support log distribute
case 'loguniform':
parallelAxis.push({
dim: i,
name: dimName[i],
type: 'log',
});
break;
default:
parallelAxis.push({
dim: i,
name: dimName[i]
});
}
}
// get data for every lines. if dim is choice type, number -> toString() // get data for every lines. if dim is choice type, number -> toString()
const paraYdata: number[][] = [];
Object.keys(eachTrialParams).map(item => { Object.keys(eachTrialParams).map(item => {
let temp: Array<number> = []; let temp: Array<number> = [];
for (let i = 0; i < dimName.length; i++) { for (let i = 0; i < dimName.length; i++) {
...@@ -169,7 +99,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -169,7 +99,7 @@ class Para extends React.Component<ParaProps, ParaState> {
Object.keys(paraYdata).map(item => { Object.keys(paraYdata).map(item => {
paraYdata[item].push(accPara[item]); paraYdata[item].push(accPara[item]);
}); });
// according acc to sort ydata // according acc to sort ydata // sort to find top percent dataset
if (paraYdata.length !== 0) { if (paraYdata.length !== 0) {
const len = paraYdata[0].length - 1; const len = paraYdata[0].length - 1;
paraYdata.sort((a, b) => b[len] - a[len]); paraYdata.sort((a, b) => b[len] - a[len]);
...@@ -191,28 +121,153 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -191,28 +121,153 @@ class Para extends React.Component<ParaProps, ParaState> {
this.swapGraph(paraData, swapAxisArr); this.swapGraph(paraData, swapAxisArr);
} }
this.getOption(paraData); this.getOption(paraData);
if (this._isMounted === true) {
this.setState(() => ({ paraBack: paraData }));
}
} }
hyperParaPic = (dataSource: Array<TableObj>, searchSpace: string) => { hyperParaPic = (source: Array<TableObj>, searchSpace: string) => {
// filter succeed trials [{}, {}, {}]
const dataSource: Array<TableObj> = source.filter(filterByStatus);
const lenOfDataSource: number = dataSource.length;
const accPara: Array<number> = []; const accPara: Array<number> = [];
// specific value array // specific value array
const eachTrialParams: Array<string> = []; const eachTrialParams: Array<string> = [];
const paraYdata: number[][] = [];
// experiment interface search space obj // experiment interface search space obj
const searchRange = JSON.parse(searchSpace); const searchRange = searchSpace !== undefined ? JSON.parse(searchSpace) : '';
const dimName = Object.keys(searchRange); const dimName = Object.keys(searchRange);
// trial-jobs interface list if (this._isMounted === true) {
Object.keys(dataSource).map(item => { this.setState(() => ({ dimName: dimName }));
const temp = dataSource[item]; }
if (temp.status === 'SUCCEEDED') {
accPara.push(temp.acc.default); const parallelAxis: Array<Dimobj> = [];
eachTrialParams.push(temp.description.parameters); // search space range and specific value [only number]
for (let i = 0; i < dimName.length; i++) {
const searchKey = searchRange[dimName[i]];
switch (searchKey._type) {
case 'uniform':
case 'quniform':
parallelAxis.push({
dim: i,
name: dimName[i],
max: searchKey._value[1],
min: searchKey._value[0]
});
break;
case 'randint':
parallelAxis.push({
dim: i,
name: dimName[i],
max: searchKey._value[0] - 1,
min: 0
});
break;
case 'choice':
const data: Array<string> = [];
for (let j = 0; j < searchKey._value.length; j++) {
data.push(searchKey._value[j].toString());
}
parallelAxis.push({
dim: i,
name: dimName[i],
type: 'category',
data: data,
boundaryGap: true,
axisLine: {
lineStyle: {
type: 'dotted', // axis type,solid,dashed,dotted
width: 1
}
},
axisTick: {
show: true,
interval: 0,
alignWithLabel: true,
},
axisLabel: {
show: true,
interval: 0,
// rotate: 30
},
});
break;
// support log distribute
case 'loguniform':
parallelAxis.push({
dim: i,
name: dimName[i],
type: 'log',
});
break;
default:
parallelAxis.push({
dim: i,
name: dimName[i]
});
} }
}); }
if (this._isMounted) { if (lenOfDataSource === 0) {
this.setState({ max: Math.max(...accPara), min: Math.min(...accPara) }, () => { const optionOfNull = {
this.getParallelAxis(dimName, searchRange, accPara, eachTrialParams, paraYdata); parallelAxis,
tooltip: {
trigger: 'item'
},
parallel: {
parallelAxisDefault: {
tooltip: {
show: true
},
axisLabel: {
formatter: function (value: string) {
const length = value.length;
if (length > 16) {
const temp = value.split('');
for (let i = 16; i < temp.length; i += 17) {
temp[i] += '\n';
}
return temp.join('');
} else {
return value;
}
}
},
}
},
visualMap: {
type: 'continuous',
min: 0,
max: 1,
color: ['#CA0000', '#FFC400', '#90EE90']
}
};
if (this._isMounted === true) {
this.setState({
paraNodata: 'No data',
option: optionOfNull,
sutrialCount: 0
});
}
} else {
Object.keys(dataSource).map(item => {
const temp = dataSource[item];
eachTrialParams.push(temp.description.parameters);
// may be a succeed trial hasn't final result
// all detail page may be break down if havn't if
if (temp.acc !== undefined) {
if (temp.acc.default !== undefined) {
accPara.push(temp.acc.default);
}
}
}); });
if (this._isMounted) {
this.setState({ max: Math.max(...accPara), min: Math.min(...accPara) }, () => {
this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams);
});
}
} }
} }
...@@ -229,9 +284,10 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -229,9 +284,10 @@ class Para extends React.Component<ParaProps, ParaState> {
// deal with response data into pic data // deal with response data into pic data
getOption = (dataObj: ParaObj) => { getOption = (dataObj: ParaObj) => {
// dataObj [[y1], [y2]... [default metric]]
const { max, min } = this.state; const { max, min } = this.state;
let parallelAxis = dataObj.parallelAxis; const parallelAxis = dataObj.parallelAxis;
let paralleData = dataObj.data; const paralleData = dataObj.data;
let visualMapObj = {}; let visualMapObj = {};
if (max === min) { if (max === min) {
visualMapObj = { visualMapObj = {
...@@ -251,7 +307,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -251,7 +307,7 @@ class Para extends React.Component<ParaProps, ParaState> {
color: ['#CA0000', '#FFC400', '#90EE90'] color: ['#CA0000', '#FFC400', '#90EE90']
}; };
} }
let optionown = { const optionown = {
parallelAxis, parallelAxis,
tooltip: { tooltip: {
trigger: 'item' trigger: 'item'
...@@ -288,21 +344,11 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -288,21 +344,11 @@ class Para extends React.Component<ParaProps, ParaState> {
} }
}; };
// please wait the data // please wait the data
if (this._isMounted) {
if (paralleData.length === 0) {
this.setState({
paraNodata: 'No data'
});
} else {
this.setState({
paraNodata: ''
});
}
}
// draw search space graph
if (this._isMounted) { if (this._isMounted) {
this.setState(() => ({ this.setState(() => ({
option: optionown option: optionown,
paraNodata: '',
sutrialCount: paralleData.length
})); }));
} }
} }
...@@ -320,6 +366,68 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -320,6 +366,68 @@ class Para extends React.Component<ParaProps, ParaState> {
this.hyperParaPic(dataSource, expSearchSpace); this.hyperParaPic(dataSource, expSearchSpace);
} }
swapReInit = () => {
const { clickCounts } = this.state;
const val = clickCounts + 1;
if (this._isMounted) {
this.setState({ isLoadConfirm: true, clickCounts: val, });
}
const { paraBack, swapAxisArr } = this.state;
const paralDim = paraBack.parallelAxis;
const paraData = paraBack.data;
let temp: number;
let dim1: number;
let dim2: number;
let bool1: boolean = false;
let bool2: boolean = false;
let bool3: boolean = false;
Object.keys(paralDim).map(item => {
const paral = paralDim[item];
switch (paral.name) {
case swapAxisArr[0]:
dim1 = paral.dim;
bool1 = true;
break;
case swapAxisArr[1]:
dim2 = paral.dim;
bool2 = true;
break;
default:
}
if (bool1 && bool2) {
bool3 = true;
}
});
// swap dim's number
Object.keys(paralDim).map(item => {
if (bool3) {
if (paralDim[item].name === swapAxisArr[0]) {
paralDim[item].dim = dim2;
}
if (paralDim[item].name === swapAxisArr[1]) {
paralDim[item].dim = dim1;
}
}
});
paralDim.sort(this.sortDimY);
// swap data array
Object.keys(paraData).map(paraItem => {
temp = paraData[paraItem][dim1];
paraData[paraItem][dim1] = paraData[paraItem][dim2];
paraData[paraItem][dim2] = temp;
});
this.getOption(paraBack);
// please wait the data
if (this._isMounted) {
this.setState(() => ({
isLoadConfirm: false
}));
}
}
sortDimY = (a: Dimobj, b: Dimobj) => { sortDimY = (a: Dimobj, b: Dimobj) => {
return a.dim - b.dim; return a.dim - b.dim;
} }
...@@ -374,11 +482,39 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -374,11 +482,39 @@ class Para extends React.Component<ParaProps, ParaState> {
}); });
} }
componentDidMount() {
this._isMounted = true;
this.reInit();
}
componentWillReceiveProps(nextProps: ParaProps) { componentWillReceiveProps(nextProps: ParaProps) {
const dataSource = nextProps.dataSource; const { dataSource, expSearchSpace, whichGraph } = nextProps;
const expSearchSpace = nextProps.expSearchSpace; if (whichGraph === '2') {
this.hyperParaPic(dataSource, expSearchSpace); this.hyperParaPic(dataSource, expSearchSpace);
}
}
shouldComponentUpdate(nextProps: ParaProps, nextState: ParaState) {
const { whichGraph } = nextProps;
const beforeGraph = this.props.whichGraph;
if (whichGraph === '2') {
if (whichGraph !== beforeGraph) {
return true;
}
const { sutrialCount, clickCounts } = nextState;
const beforeCount = this.state.sutrialCount;
const beforeClickCount = this.state.clickCounts;
if (sutrialCount !== beforeCount) {
return true;
}
if (clickCounts !== beforeClickCount) {
return true;
}
}
return false;
} }
componentWillUnmount() { componentWillUnmount() {
...@@ -386,7 +522,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -386,7 +522,7 @@ class Para extends React.Component<ParaProps, ParaState> {
} }
render() { render() {
const { option, paraNodata, dimName } = this.state; const { option, paraNodata, dimName, isLoadConfirm } = this.state;
return ( return (
<Row className="parameter"> <Row className="parameter">
<Row> <Row>
...@@ -423,7 +559,8 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -423,7 +559,8 @@ class Para extends React.Component<ParaProps, ParaState> {
<Button <Button
type="primary" type="primary"
className="changeBtu tableButton" className="changeBtu tableButton"
onClick={this.reInit} onClick={this.swapReInit}
disabled={isLoadConfirm}
> >
Confirm Confirm
</Button> </Button>
...@@ -434,7 +571,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -434,7 +571,7 @@ class Para extends React.Component<ParaProps, ParaState> {
<ReactEcharts <ReactEcharts
option={option} option={option}
style={this.chartMulineStyle} style={this.chartMulineStyle}
lazyUpdate={true} // lazyUpdate={true}
notMerge={true} // update now notMerge={true} // update now
/> />
<div className="noneData">{paraNodata}</div> <div className="noneData">{paraNodata}</div>
......
import * as React from 'react'; import * as React from 'react';
import axios from 'axios'; import axios from 'axios';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { import { Row, Table, Button, Popconfirm, Modal, Checkbox } from 'antd';
Row, Table, Button, Popconfirm, Modal, Checkbox
} from 'antd';
const CheckboxGroup = Checkbox.Group; const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const'; import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob } from '../../static/function'; import { convertDuration, intermediateGraphOption, killJob } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface'; import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow'; import OpenRow from '../public-child/OpenRow';
// import DefaultMetric from '../public-child/DefaultMetrc'; import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/search.scss'; import '../../static/style/search.scss';
require('../../static/style/tableStatus.css'); require('../../static/style/tableStatus.css');
require('../../static/style/logPath.scss'); require('../../static/style/logPath.scss');
...@@ -33,6 +30,7 @@ interface TableListProps { ...@@ -33,6 +30,7 @@ interface TableListProps {
platform: string; platform: string;
logCollection: boolean; logCollection: boolean;
isMultiPhase: boolean; isMultiPhase: boolean;
isTableLoading: boolean;
} }
interface TableListState { interface TableListState {
...@@ -197,7 +195,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -197,7 +195,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() { render() {
const { entries, tableSource, updateList } = this.props; const { entries, tableSource, updateList, isTableLoading } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state; const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state;
let showTitle = COLUMN; let showTitle = COLUMN;
let bgColor = ''; let bgColor = '';
...@@ -420,6 +418,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -420,6 +418,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
dataSource={tableSource} dataSource={tableSource}
className="commonTableStyle" className="commonTableStyle"
pagination={{ pageSize: entries }} pagination={{ pageSize: entries }}
loading={isTableLoading}
/> />
{/* Intermediate Result Modal */} {/* Intermediate Result Modal */}
<Modal <Modal
......
import axios from 'axios'; import axios from 'axios';
import { import { message } from 'antd';
message
} from 'antd';
import { MANAGER_IP } from './const'; import { MANAGER_IP } from './const';
import { FinalResult, FinalType } from './interface'; import { FinalResult, FinalType, TableObj } from './interface';
const convertTime = (num: number) => { const convertTime = (num: number) => {
if (num % 3600 === 0) { if (num % 3600 === 0) {
...@@ -131,7 +129,16 @@ const killJob = (key: number, id: string, status: string, updateList: Function) ...@@ -131,7 +129,16 @@ const killJob = (key: number, id: string, status: string, updateList: Function)
}); });
}; };
const filterByStatus = (item: TableObj) => {
return item.status === 'SUCCEEDED';
};
// a waittiong trial may havn't start time
const filterDuration = (item: TableObj) => {
return item.status !== 'WAITING';
};
export { export {
convertTime, convertDuration, getFinalResult, convertTime, convertDuration, getFinalResult, getFinal,
getFinal, intermediateGraphOption, killJob intermediateGraphOption, killJob, filterByStatus, filterDuration
}; };
...@@ -26,7 +26,7 @@ interface ErrorParameter { ...@@ -26,7 +26,7 @@ interface ErrorParameter {
interface Parameters { interface Parameters {
parameters: ErrorParameter; parameters: ErrorParameter;
logPath?: string; logPath?: string;
intermediate?: Array<number>; intermediate: Array<number>;
} }
interface Experiment { interface Experiment {
......
$color: #f2f2f2;
.formatStr{
border: 1px solid #8f8f8f;
color: #333;
padding: 5px 10px;
background-color: #fff;
}
.format {
.ant-modal-header{
background-color: $color;
border-bottom: none;
}
.ant-modal-footer{
background-color: $color;
border-top: none;
}
.ant-modal-body{
background-color: $color;
padding: 10px 24px !important;
}
}
...@@ -52,4 +52,3 @@ ...@@ -52,4 +52,3 @@
.link{ .link{
margin-bottom: 10px; margin-bottom: 10px;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment