Unverified Commit 611a45fc authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #19 from microsoft/master

pull code
parents 841d4677 e267a737
......@@ -26,7 +26,7 @@ import random
import numpy as np
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index, randint_to_quniform
import nni.parameter_expressions as parameter_expressions
......@@ -175,6 +175,7 @@ class EvolutionTuner(Tuner):
search_space : dict
"""
self.searchspace_json = search_space
randint_to_quniform(self.searchspace_json)
self.space = json2space(self.searchspace_json)
self.random_state = np.random.RandomState()
......
......@@ -31,7 +31,7 @@ import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, randint_to_quniform
import nni.parameter_expressions as parameter_expressions
_logger = logging.getLogger(__name__)
......@@ -357,6 +357,7 @@ class Hyperband(MsgDispatcherBase):
number of trial jobs
"""
self.searchspace_json = data
randint_to_quniform(self.searchspace_json)
self.random_state = np.random.RandomState()
def handle_trial_end(self, data):
......
......@@ -27,7 +27,7 @@ import logging
import hyperopt as hp
import numpy as np
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index, randint_to_quniform
logger = logging.getLogger('hyperopt_AutoML')
......@@ -153,14 +153,14 @@ def _add_index(in_x, parameter):
Will change to format in hyperopt, like:
{'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}}
"""
if TYPE not in in_x: # if at the top level
if NodeType.TYPE not in in_x: # if at the top level
out_y = dict()
for key, value in parameter.items():
out_y[key] = _add_index(in_x[key], value)
return out_y
elif isinstance(in_x, dict):
value_type = in_x[TYPE]
value_format = in_x[VALUE]
value_type = in_x[NodeType.TYPE]
value_format = in_x[NodeType.VALUE]
if value_type == "choice":
choice_name = parameter[0] if isinstance(parameter,
list) else parameter
......@@ -173,15 +173,14 @@ def _add_index(in_x, parameter):
choice_value_format = item[1]
if choice_key == choice_name:
return {
INDEX:
pos,
VALUE: [
NodeType.INDEX: pos,
NodeType.VALUE: [
choice_name,
_add_index(choice_value_format, parameter[1])
]
}
elif choice_name == item:
return {INDEX: pos, VALUE: item}
return {NodeType.INDEX: pos, NodeType.VALUE: item}
else:
return parameter
......@@ -232,6 +231,8 @@ class HyperoptTuner(Tuner):
search_space : dict
"""
self.json = search_space
randint_to_quniform(self.json)
search_space_instance = json2space(self.json)
rstate = np.random.RandomState()
trials = hp.Trials()
......
......@@ -133,7 +133,7 @@ class MetisTuner(Tuner):
self.x_bounds[idx] = bounds
self.x_types[idx] = 'discrete_int'
elif key_type == 'randint':
self.x_bounds[idx] = [0, key_range[0]]
self.x_bounds[idx] = [key_range[0], key_range[1]]
self.x_types[idx] = 'range_int'
elif key_type == 'uniform':
self.x_bounds[idx] = [key_range[0], key_range[1]]
......
......@@ -21,21 +21,24 @@
smac_tuner.py
"""
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
import sys
import logging
import numpy as np
import json_tricks
from enum import Enum, unique
from .convert_ss_to_scenario import generate_scenario
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from smac.utils.io.cmd_reader import CMDReader
from smac.scenario.scenario import Scenario
from smac.facade.smac_facade import SMAC
from smac.facade.roar_facade import ROAR
from smac.facade.epils_facade import EPILS
from ConfigSpaceNNI import Configuration
from .convert_ss_to_scenario import generate_scenario
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, randint_to_quniform
class SMACTuner(Tuner):
......@@ -57,6 +60,7 @@ class SMACTuner(Tuner):
self.update_ss_done = False
self.loguniform_key = set()
self.categorical_dict = {}
self.cs = None
def _main_cli(self):
"""Main function of SMAC for CLI interface
......@@ -66,7 +70,7 @@ class SMACTuner(Tuner):
instance
optimizer
"""
self.logger.info("SMAC call: %s" % (" ".join(sys.argv)))
self.logger.info("SMAC call: %s", " ".join(sys.argv))
cmd_reader = CMDReader()
args, _ = cmd_reader.read_cmd()
......@@ -95,6 +99,7 @@ class SMACTuner(Tuner):
# Create scenario-object
scen = Scenario(args.scenario_file, [])
self.cs = scen.cs
if args.mode == "SMAC":
optimizer = SMAC(
......@@ -134,6 +139,7 @@ class SMACTuner(Tuner):
search_space:
search space
"""
randint_to_quniform(search_space)
if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
......@@ -258,4 +264,45 @@ class SMACTuner(Tuner):
return params
def import_data(self, data):
pass
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
self.logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
self.logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
# convert the keys in loguniform and categorical types
valid_entry = True
for key, value in _params.items():
if key in self.loguniform_key:
_params[key] = np.log(value)
elif key in self.categorical_dict:
if value in self.categorical_dict[key]:
_params[key] = self.categorical_dict[key].index(value)
else:
self.logger.info("The value %s of key %s is not in search space.", str(value), key)
valid_entry = False
break
if not valid_entry:
continue
# start import this data entry
_completed_num += 1
config = Configuration(self.cs, values=_params)
if self.optimize_mode is OptimizeMode.Maximize:
_value = -_value
if self.first_one:
self.smbo_solver.nni_smac_receive_first_run(config, _value)
self.first_one = False
else:
self.smbo_solver.nni_smac_receive_runs(config, _value)
self.logger.info("Successfully import data to smac tuner, total data: %d, imported data: %d.", len(data), _completed_num)
......@@ -36,7 +36,8 @@ __all__ = [
'qnormal',
'lognormal',
'qlognormal',
'function_choice'
'function_choice',
'mutable_layer'
]
......@@ -78,6 +79,9 @@ if trial_env_vars.NNI_PLATFORM is None:
def function_choice(*funcs, name=None):
return random.choice(funcs)()
def mutable_layer():
raise RuntimeError('Cannot call nni.mutable_layer in this mode')
else:
def choice(options, name=None, key=None):
......@@ -113,6 +117,42 @@ else:
def function_choice(funcs, name=None, key=None):
return funcs[_get_param(key)]()
def mutable_layer(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size=0):
'''execute the chosen function and inputs.
Below is an example of chosen function and inputs:
{
"mutable_id": {
"mutable_layer_id": {
"chosen_layer": "pool",
"chosen_inputs": ["out1", "out3"]
}
}
}
Parameters:
---------------
mutable_id: the name of this mutable_layer block (which could have multiple mutable layers)
mutable_layer_id: the name of a mutable layer in this block
funcs: dict of function calls
funcs_args:
fixed_inputs:
optional_inputs: dict of optional inputs
optional_input_size: number of candidate inputs to be chosen
'''
mutable_block = _get_param(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
real_chosen_inputs = [optional_inputs[input_name] for input_name in chosen_inputs]
layer_out = funcs[chosen_layer]([fixed_inputs, real_chosen_inputs], *funcs_args[chosen_layer])
return layer_out
def _get_param(key):
if trial._params is None:
trial.get_next_parameter()
......
......@@ -40,6 +40,7 @@ class OptimizeMode(Enum):
Minimize = 'minimize'
Maximize = 'maximize'
class NodeType:
"""Node Type class
"""
......@@ -83,6 +84,7 @@ def extract_scalar_reward(value, scalar_key='default'):
raise RuntimeError('Incorrect final result: the final result should be float/int, or a dict which has a key named "default" whose value is float/int.')
return reward
def convert_dict2tuple(value):
"""
convert dict type to tuple to solve unhashable problem.
......@@ -94,9 +96,30 @@ def convert_dict2tuple(value):
else:
return value
def init_dispatcher_logger():
""" Initialize dispatcher logging configuration"""
logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL)
def randint_to_quniform(in_x):
if isinstance(in_x, dict):
if NodeType.TYPE in in_x.keys():
if in_x[NodeType.TYPE] == 'randint':
value = in_x[NodeType.VALUE]
value.append(1)
in_x[NodeType.TYPE] = 'quniform'
in_x[NodeType.VALUE] = value
elif in_x[NodeType.TYPE] == 'choice':
randint_to_quniform(in_x[NodeType.VALUE])
else:
for key in in_x.keys():
randint_to_quniform(in_x[key])
elif isinstance(in_x, list):
for temp in in_x:
randint_to_quniform(temp)
......@@ -192,19 +192,23 @@ class Overview extends React.Component<{}, OverviewState> {
method: 'GET'
})
.then(res => {
if (res.status === 200 && this._isMounted) {
if (res.status === 200) {
const errors = res.data.errors;
if (errors.length !== 0) {
if (this._isMounted) {
this.setState({
status: res.data.status,
errorStr: res.data.errors[0]
});
}
} else {
if (this._isMounted) {
this.setState({
status: res.data.status,
});
}
}
}
});
}
......@@ -254,7 +258,8 @@ class Overview extends React.Component<{}, OverviewState> {
case 'SUCCEEDED':
profile.succTrial += 1;
const desJobDetail: Parameters = {
parameters: {}
parameters: {},
intermediate: []
};
const duration = (tableData[item].endTime - tableData[item].startTime) / 1000;
const acc = getFinal(tableData[item].finalMetricData);
......
......@@ -27,6 +27,11 @@ interface TrialDetailState {
entriesInSelect: string;
searchSpace: string;
isMultiPhase: boolean;
isTableLoading: boolean;
whichGraph: string;
hyperCounts: number; // user click the hyper-parameter counts
durationCounts: number;
intermediateCounts: number;
}
class TrialsDetail extends React.Component<{}, TrialDetailState> {
......@@ -70,9 +75,14 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
experimentLogCollection: false,
entriesTable: 20,
entriesInSelect: '20',
isHasSearch: false,
searchSpace: '',
isMultiPhase: false
whichGraph: '1',
isHasSearch: false,
isMultiPhase: false,
isTableLoading: false,
hyperCounts: 0,
durationCounts: 0,
intermediateCounts: 0
};
}
......@@ -85,6 +95,9 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
])
.then(axios.spread((res, res1) => {
if (res.status === 200 && res1.status === 200) {
if (this._isMounted === true) {
this.setState(() => ({ isTableLoading: true }));
}
const trialJobs = res.data;
const metricSource = res1.data;
const trialTable: Array<TableObj> = [];
......@@ -175,6 +188,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
}
if (this._isMounted) {
this.setState(() => ({
isTableLoading: false,
tableListSource: trialTable
}));
}
......@@ -239,17 +253,12 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
}
handleEntriesSelect = (value: string) => {
switch (value) {
case '20':
this.setState(() => ({ entriesTable: 20 }));
break;
case '50':
this.setState(() => ({ entriesTable: 50 }));
break;
case '100':
this.setState(() => ({ entriesTable: 100 }));
break;
case 'all':
// user select isn't 'all'
if (value !== 'all') {
if (this._isMounted) {
this.setState(() => ({ entriesTable: parseInt(value, 10) }));
}
} else {
const { tableListSource } = this.state;
if (this._isMounted) {
this.setState(() => ({
......@@ -257,8 +266,13 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
entriesTable: tableListSource.length
}));
}
break;
default:
}
}
handleWhichTabs = (activeKey: string) => {
// const which = JSON.parse(activeKey);
if (this._isMounted) {
this.setState(() => ({ whichGraph: activeKey }));
}
}
......@@ -315,18 +329,21 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
const {
tableListSource, searchResultSource, isHasSearch, isMultiPhase,
entriesTable, experimentPlatform, searchSpace, experimentLogCollection
entriesTable, experimentPlatform, searchSpace, experimentLogCollection,
whichGraph, isTableLoading
} = this.state;
const source = isHasSearch ? searchResultSource : tableListSource;
return (
<div>
<div className="trial" id="tabsty">
<Tabs type="card">
<Tabs type="card" onChange={this.handleWhichTabs}>
{/* <TabPane tab={this.titleOfacc} key="1" destroyInactiveTabPane={true}> */}
<TabPane tab={this.titleOfacc} key="1">
<Row className="graph">
<DefaultPoint
height={432}
showSource={source}
whichGraph={whichGraph}
/>
</Row>
</TabPane>
......@@ -335,14 +352,16 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
<Para
dataSource={source}
expSearchSpace={searchSpace}
whichGraph={whichGraph}
/>
</Row>
</TabPane>
<TabPane tab={this.titleOfDuration} key="3">
<Duration source={source} />
<Duration source={source} whichGraph={whichGraph} />
{/* <Duration source={source} whichGraph={whichGraph} clickCounts={durationCounts} /> */}
</TabPane>
<TabPane tab={this.titleOfIntermediate} key="4">
<Intermediate source={source} />
<Intermediate source={source} whichGraph={whichGraph} />
</TabPane>
</Tabs>
</div>
......@@ -388,6 +407,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
<TableList
entries={entriesTable}
tableSource={source}
isTableLoading={isTableLoading}
isMultiPhase={isMultiPhase}
platform={experimentPlatform}
updateList={this.getDetailSource}
......
......@@ -28,12 +28,11 @@ class IntermediateVal extends React.Component<IntermediateValProps, {}> {
if (wei > 6) {
result = `${lastVal.toFixed(6)}`;
}
}
if (status === 'SUCCEEDED') {
result = `${lastVal.toFixed(6)} (FINAL)`;
result = `${result} (FINAL)`;
} else {
result = `${lastVal.toFixed(6)} (LATEST)`;
}
result = `${result} (LATEST)`;
}
} else {
result = '--';
......
......@@ -3,9 +3,10 @@ import * as copy from 'copy-to-clipboard';
import PaiTrialLog from '../public-child/PaiTrialLog';
import TrialLog from '../public-child/TrialLog';
import { TableObj } from '../../static/interface';
import { Row, Tabs, Button, message } from 'antd';
import { Row, Tabs, Button, message, Modal } from 'antd';
import { MANAGER_IP } from '../../static/const';
import '../../static/style/overview.scss';
import '../../static/style/copyParameter.scss';
import JSONTree from 'react-json-tree';
const TabPane = Tabs.TabPane;
......@@ -17,43 +18,62 @@ interface OpenRowProps {
}
interface OpenRowState {
idList: Array<string>;
isShowFormatModal: boolean;
formatStr: string;
}
class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
public _isMounted: boolean;
constructor(props: OpenRowProps) {
super(props);
this.state = {
idList: ['']
isShowFormatModal: false,
formatStr: ''
};
}
showFormatModal = (record: TableObj) => {
// get copy parameters
const params = JSON.stringify(record.description.parameters, null, 4);
// open modal with format string
if (this._isMounted === true) {
this.setState(() => ({ isShowFormatModal: true, formatStr: params }));
}
}
hideFormatModal = () => {
// close modal, destroy state format string data
if (this._isMounted === true) {
this.setState(() => ({ isShowFormatModal: false, formatStr: '' }));
}
}
copyParams = (record: TableObj) => {
copyParams = () => {
// json format
const params = JSON.stringify(record.description.parameters, null, 4);
if (copy(params)) {
const { formatStr } = this.state;
if (copy(formatStr)) {
message.destroy();
message.success('Success copy parameters to clipboard in form of python dict !', 3);
const { idList } = this.state;
const copyIdList: Array<string> = idList;
copyIdList[copyIdList.length - 1] = record.id;
this.setState(() => ({
idList: copyIdList
}));
} else {
message.destroy();
message.error('Failed !', 2);
}
this.hideFormatModal();
}
componentDidMount() {
this._isMounted = true;
}
componentWillUnmount() {
this._isMounted = false;
}
render() {
const { trainingPlatform, record, logCollection, multiphase } = this.props;
const { idList } = this.state;
const { isShowFormatModal, formatStr } = this.state;
let isClick = false;
let isHasParameters = true;
if (idList.indexOf(record.id) !== -1) { isClick = true; }
if (record.description.parameters.error) {
isHasParameters = false;
}
......@@ -101,7 +121,7 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
</Row>
<Row className="copy">
<Button
onClick={this.copyParams.bind(this, record)}
onClick={this.showFormatModal.bind(this, record)}
>
Copy as python
</Button>
......@@ -128,6 +148,21 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
}
</TabPane>
</Tabs>
<Modal
title="Format"
okText="Copy"
centered={true}
visible={isShowFormatModal}
onCancel={this.hideFormatModal}
maskClosable={false} // click mongolian layer don't close modal
onOk={this.copyParams}
destroyOnClose={true}
width="60%"
className="format"
>
{/* write string in pre to show format string */}
<pre className="formatStr">{formatStr}</pre>
</Modal>
</Row >
);
}
......
import * as React from 'react';
import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface';
require('echarts/lib/chart/scatter');
require('echarts/lib/component/tooltip');
......@@ -8,11 +9,13 @@ require('echarts/lib/component/title');
interface DefaultPointProps {
showSource: Array<TableObj>;
height: number;
whichGraph: string;
}
interface DefaultPointState {
defaultSource: object;
accNodata: string;
succeedTrials: number;
}
class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
......@@ -22,15 +25,46 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
super(props);
this.state = {
defaultSource: {},
accNodata: 'No data'
accNodata: '',
succeedTrials: 10000000
};
}
defaultMetric = (showSource: Array<TableObj>) => {
defaultMetric = (succeedSource: Array<TableObj>) => {
const accSource: Array<DetailAccurPoint> = [];
const showSource: Array<TableObj> = succeedSource.filter(filterByStatus);
const lengthOfSource = showSource.length;
const tooltipDefault = lengthOfSource === 0 ? 'No data' : '';
if (this._isMounted === true) {
this.setState(() => ({
succeedTrials: lengthOfSource,
accNodata: tooltipDefault
}));
}
if (lengthOfSource === 0) {
const nullGraph = {
grid: {
left: '8%'
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
}
};
if (this._isMounted === true) {
this.setState(() => ({
defaultSource: nullGraph
}));
}
} else {
const resultList: Array<number | string>[] = [];
Object.keys(showSource).map(item => {
const temp = showSource[item];
if (temp.status === 'SUCCEEDED' && temp.acc !== undefined) {
if (temp.acc !== undefined) {
if (temp.acc.default !== undefined) {
const searchSpace = temp.description.parameters;
accSource.push({
......@@ -41,7 +75,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
}
});
const resultList: Array<number | string>[] = [];
Object.keys(accSource).map(item => {
const items = accSource[item];
let temp: Array<number | string>;
......@@ -89,25 +122,34 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}]
};
if (this._isMounted === true) {
this.setState({ defaultSource: allAcuracy }, () => {
if (resultList.length === 0) {
this.setState({
accNodata: 'No data'
});
} else {
this.setState({
accNodata: ''
});
this.setState(() => ({
defaultSource: allAcuracy
}));
}
});
}
}
// update parent component state
componentWillReceiveProps(nextProps: DefaultPointProps) {
const showSource = nextProps.showSource;
const { whichGraph, showSource } = nextProps;
if (whichGraph === '1') {
this.defaultMetric(showSource);
}
}
shouldComponentUpdate(nextProps: DefaultPointProps, nextState: DefaultPointState) {
const { whichGraph } = nextProps;
const succTrial = this.state.succeedTrials;
const { succeedTrials } = nextState;
if (whichGraph === '1') {
if (succeedTrials !== succTrial) {
return true;
}
}
// only whichGraph !== '1', default metric can't update
return false;
}
componentDidMount() {
this._isMounted = true;
......@@ -131,6 +173,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}}
theme="my_theme"
notMerge={true} // update now
// lazyUpdate={true}
/>
<div className="showMess">{accNodata}</div>
</div>
......
import * as React from 'react';
import ReactEcharts from 'echarts-for-react';
import { TableObj } from 'src/static/interface';
import { filterDuration } from 'src/static/function';
require('echarts/lib/chart/bar');
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
......@@ -12,6 +13,7 @@ interface Runtrial {
interface DurationProps {
source: Array<TableObj>;
whichGraph: string;
}
interface DurationState {
......@@ -26,11 +28,62 @@ class Duration extends React.Component<DurationProps, DurationState> {
super(props);
this.state = {
durationSource: {}
durationSource: this.initDuration(this.props.source),
};
}
initDuration = (source: Array<TableObj>) => {
const trialId: Array<string> = [];
const trialTime: Array<number> = [];
const trialJobs = source.filter(filterDuration);
Object.keys(trialJobs).map(item => {
const temp = trialJobs[item];
trialId.push(temp.sequenceId);
trialTime.push(temp.duration);
});
return {
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'shadow'
}
},
grid: {
bottom: '3%',
containLabel: true,
left: '1%',
right: '4%'
},
dataZoom: [{
type: 'slider',
name: 'trial',
filterMode: 'filter',
yAxisIndex: 0,
orient: 'vertical'
}, {
type: 'slider',
name: 'trial',
filterMode: 'filter',
xAxisIndex: 0
}],
xAxis: {
name: 'Time',
type: 'value',
},
yAxis: {
name: 'Trial',
type: 'category',
data: trialId
},
series: [{
type: 'bar',
data: trialTime
}]
};
}
getOption = (dataObj: Runtrial) => {
return {
tooltip: {
......@@ -74,17 +127,16 @@ class Duration extends React.Component<DurationProps, DurationState> {
};
}
drawDurationGraph = (trialJobs: Array<TableObj>) => {
drawDurationGraph = (source: Array<TableObj>) => {
// why this function run two times when props changed?
const trialId: Array<string> = [];
const trialTime: Array<number> = [];
const trialRun: Array<Runtrial> = [];
const trialJobs = source.filter(filterDuration);
Object.keys(trialJobs).map(item => {
const temp = trialJobs[item];
if (temp.status !== 'WAITING') {
trialId.push(temp.sequenceId);
trialTime.push(temp.duration);
}
});
trialRun.push({
trialId: trialId,
......@@ -97,17 +149,42 @@ class Duration extends React.Component<DurationProps, DurationState> {
}
}
componentWillReceiveProps(nextProps: DurationProps) {
const trialJobs = nextProps.source;
this.drawDurationGraph(trialJobs);
}
componentDidMount() {
this._isMounted = true;
// init: user don't search
const {source} = this.props;
const { source } = this.props;
this.drawDurationGraph(source);
}
componentWillReceiveProps(nextProps: DurationProps) {
const { whichGraph, source } = nextProps;
if (whichGraph === '3') {
this.drawDurationGraph(source);
}
}
shouldComponentUpdate(nextProps: DurationProps, nextState: DurationState) {
const { whichGraph, source } = nextProps;
if (whichGraph === '3') {
const beforeSource = this.props.source;
if (whichGraph !== this.props.whichGraph) {
return true;
}
if (source.length !== beforeSource.length) {
return true;
}
if (source[source.length - 1].duration !== beforeSource[beforeSource.length - 1].duration) {
return true;
}
if (source[source.length - 1].status !== beforeSource[beforeSource.length - 1].status) {
return true;
}
}
return false;
}
componentWillUnmount() {
this._isMounted = false;
......@@ -121,6 +198,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
option={durationSource}
style={{ width: '95%', height: 412, margin: '0 auto' }}
theme="my_theme"
notMerge={true} // update now
/>
</div>
);
......
......@@ -11,16 +11,21 @@ interface Intermedia {
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
interface IntermediateState {
detailSource: Array<TableObj>;
interSource: object;
filterSource: Array<TableObj>;
eachIntermediateNum: number; // trial's intermediate number count
isLoadconfirmBtn: boolean;
isFilter: boolean;
length: number;
clickCounts: number; // user filter intermediate click confirm btn's counts
}
interface IntermediateProps {
source: Array<TableObj>;
whichGraph: string;
}
class Intermediate extends React.Component<IntermediateProps, IntermediateState> {
......@@ -34,39 +39,25 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
constructor(props: IntermediateProps) {
super(props);
this.state = {
detailSource: [],
interSource: {},
filterSource: [],
eachIntermediateNum: 1,
isLoadconfirmBtn: false,
isFilter: false
isFilter: false,
length: 100000,
clickCounts: 0
};
}
initMediate = () => {
const option = {
grid: {
left: '5%',
top: 40,
containLabel: true
},
xAxis: {
type: 'category',
boundaryGap: false,
},
yAxis: {
type: 'value',
name: 'Scape'
}
};
drawIntermediate = (source: Array<TableObj>) => {
if (source.length > 0) {
if (this._isMounted) {
this.setState(() => ({
interSource: option
length: source.length,
detailSource: source
}));
}
}
drawIntermediate = (source: Array<TableObj>) => {
if (source.length > 0) {
const trialIntermediate: Array<Intermedia> = [];
Object.keys(source).map(item => {
const temp = source[item];
......@@ -140,7 +131,24 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
}));
}
} else {
this.initMediate();
const nullData = {
grid: {
left: '5%',
top: 40,
containLabel: true
},
xAxis: {
type: 'category',
boundaryGap: false,
},
yAxis: {
type: 'value',
name: 'Scape'
}
};
if (this._isMounted) {
this.setState(() => ({ interSource: nullData }));
}
}
}
......@@ -183,8 +191,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
this.setState({ filterSource: filterSource });
}
this.drawIntermediate(filterSource);
const counts = this.state.clickCounts + 1;
this.setState({ isLoadconfirmBtn: false, clickCounts: counts });
}
this.setState({ isLoadconfirmBtn: false });
});
}
}
......@@ -204,19 +213,65 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
this.drawIntermediate(source);
}
componentWillReceiveProps(nextProps: IntermediateProps) {
const { isFilter, filterSource } = this.state;
componentWillReceiveProps(nextProps: IntermediateProps, nextState: IntermediateState) {
const { isFilter, filterSource } = nextState;
const { whichGraph, source } = nextProps;
if (whichGraph === '4') {
if (isFilter === true) {
const pointVal = this.pointInput !== null ? this.pointInput.value : '';
const minVal = this.minValInput !== null ? this.minValInput.value : '';
if (pointVal === '' && minVal === '') {
this.drawIntermediate(nextProps.source);
this.drawIntermediate(source);
} else {
this.drawIntermediate(filterSource);
}
} else {
this.drawIntermediate(nextProps.source);
this.drawIntermediate(source);
}
}
}
shouldComponentUpdate(nextProps: IntermediateProps, nextState: IntermediateState) {
const { whichGraph } = nextProps;
const beforeGraph = this.props.whichGraph;
if (whichGraph === '4') {
const { source } = nextProps;
const { isFilter, length, clickCounts } = nextState;
const beforeLength = this.state.length;
const beforeSource = this.state.detailSource;
const beforeClickCounts = this.state.clickCounts;
if (isFilter !== this.state.isFilter) {
return true;
}
if (clickCounts !== beforeClickCounts) {
return true;
}
if (isFilter === false) {
if (whichGraph !== beforeGraph) {
return true;
}
if (length !== beforeLength) {
return true;
}
if (source[source.length - 1].description.intermediate.length !==
beforeSource[beforeSource.length - 1].description.intermediate.length) {
return true;
}
if (source[source.length - 1].duration !== beforeSource[beforeSource.length - 1].duration) {
return true;
}
if (source[source.length - 1].status !== beforeSource[beforeSource.length - 1].status) {
return true;
}
}
}
return false;
}
componentWillUnmount() {
......@@ -225,7 +280,6 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
render() {
const { interSource, isLoadconfirmBtn, isFilter } = this.state;
return (
<div>
{/* style in para.scss */}
......
import * as React from 'react';
import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { Row, Col, Select, Button, message } from 'antd';
import { ParaObj, Dimobj, TableObj, SearchSpace } from '../../static/interface';
import { ParaObj, Dimobj, TableObj } from '../../static/interface';
const Option = Select.Option;
require('echarts/lib/chart/parallel');
require('echarts/lib/component/tooltip');
......@@ -11,6 +12,7 @@ require('../../static/style/para.scss');
require('../../static/style/button.scss');
interface ParaState {
// paraSource: Array<TableObj>;
option: object;
paraBack: ParaObj;
dimName: Array<string>;
......@@ -19,11 +21,15 @@ interface ParaState {
paraNodata: string;
max: number; // graph color bar limit
min: number;
sutrialCount: number; // succeed trial numbers for SUC
clickCounts: number;
isLoadConfirm: boolean;
}
interface ParaProps {
dataSource: Array<TableObj>;
expSearchSpace: string;
whichGraph: string;
}
message.config({
......@@ -45,6 +51,8 @@ class Para extends React.Component<ParaProps, ParaState> {
constructor(props: ParaProps) {
super(props);
this.state = {
// paraSource: [],
// option: this.hyperParaPic,
option: {},
dimName: [],
paraBack: {
......@@ -58,27 +66,80 @@ class Para extends React.Component<ParaProps, ParaState> {
percent: 0,
paraNodata: '',
min: 0,
max: 1
max: 1,
sutrialCount: 10000000,
clickCounts: 1,
isLoadConfirm: false
};
}
componentDidMount() {
this._isMounted = true;
this.reInit();
}
getParallelAxis =
(
dimName: Array<string>, searchRange: SearchSpace,
accPara: Array<number>,
eachTrialParams: Array<string>, paraYdata: number[][]
dimName: Array<string>, parallelAxis: Array<Dimobj>,
accPara: Array<number>, eachTrialParams: Array<string>
) => {
if (this._isMounted) {
this.setState(() => ({
dimName: dimName
}));
// get data for every lines. if dim is choice type, number -> toString()
const paraYdata: number[][] = [];
Object.keys(eachTrialParams).map(item => {
let temp: Array<number> = [];
for (let i = 0; i < dimName.length; i++) {
if ('type' in parallelAxis[i]) {
temp.push(
eachTrialParams[item][dimName[i]].toString()
);
} else {
temp.push(
eachTrialParams[item][dimName[i]]
);
}
}
paraYdata.push(temp);
});
// add acc
Object.keys(paraYdata).map(item => {
paraYdata[item].push(accPara[item]);
});
// according acc to sort ydata // sort to find top percent dataset
if (paraYdata.length !== 0) {
const len = paraYdata[0].length - 1;
paraYdata.sort((a, b) => b[len] - a[len]);
}
const paraData = {
parallelAxis: parallelAxis,
data: paraYdata
};
const { percent, swapAxisArr } = this.state;
// need to cut down the data
if (percent !== 0) {
const linesNum = paraData.data.length;
// Math.ceil rather than Math.floor to avoid lost lines
const len = Math.ceil(linesNum * percent);
paraData.data.length = len;
}
// need to swap the yAxis
if (swapAxisArr.length >= 2) {
this.swapGraph(paraData, swapAxisArr);
}
this.getOption(paraData);
if (this._isMounted === true) {
this.setState(() => ({ paraBack: paraData }));
}
}
hyperParaPic = (source: Array<TableObj>, searchSpace: string) => {
// filter succeed trials [{}, {}, {}]
const dataSource: Array<TableObj> = source.filter(filterByStatus);
const lenOfDataSource: number = dataSource.length;
const accPara: Array<number> = [];
// specific value array
const eachTrialParams: Array<string> = [];
// experiment interface search space obj
const searchRange = searchSpace !== undefined ? JSON.parse(searchSpace) : '';
const dimName = Object.keys(searchRange);
if (this._isMounted === true) {
this.setState(() => ({ dimName: dimName }));
}
const parallelAxis: Array<Dimobj> = [];
// search space range and specific value [only number]
for (let i = 0; i < dimName.length; i++) {
......@@ -149,72 +210,66 @@ class Para extends React.Component<ParaProps, ParaState> {
}
}
// get data for every lines. if dim is choice type, number -> toString()
Object.keys(eachTrialParams).map(item => {
let temp: Array<number> = [];
for (let i = 0; i < dimName.length; i++) {
if ('type' in parallelAxis[i]) {
temp.push(
eachTrialParams[item][dimName[i]].toString()
);
} else {
temp.push(
eachTrialParams[item][dimName[i]]
);
if (lenOfDataSource === 0) {
const optionOfNull = {
parallelAxis,
tooltip: {
trigger: 'item'
},
parallel: {
parallelAxisDefault: {
tooltip: {
show: true
},
axisLabel: {
formatter: function (value: string) {
const length = value.length;
if (length > 16) {
const temp = value.split('');
for (let i = 16; i < temp.length; i += 17) {
temp[i] += '\n';
}
return temp.join('');
} else {
return value;
}
paraYdata.push(temp);
});
// add acc
Object.keys(paraYdata).map(item => {
paraYdata[item].push(accPara[item]);
});
// according acc to sort ydata
if (paraYdata.length !== 0) {
const len = paraYdata[0].length - 1;
paraYdata.sort((a, b) => b[len] - a[len]);
}
const paraData = {
parallelAxis: parallelAxis,
data: paraYdata
};
const { percent, swapAxisArr } = this.state;
// need to cut down the data
if (percent !== 0) {
const linesNum = paraData.data.length;
// Math.ceil rather than Math.floor to avoid lost lines
const len = Math.ceil(linesNum * percent);
paraData.data.length = len;
},
}
// need to swap the yAxis
if (swapAxisArr.length >= 2) {
this.swapGraph(paraData, swapAxisArr);
},
visualMap: {
type: 'continuous',
min: 0,
max: 1,
color: ['#CA0000', '#FFC400', '#90EE90']
}
this.getOption(paraData);
};
if (this._isMounted === true) {
this.setState({
paraNodata: 'No data',
option: optionOfNull,
sutrialCount: 0
});
}
hyperParaPic = (dataSource: Array<TableObj>, searchSpace: string) => {
const accPara: Array<number> = [];
// specific value array
const eachTrialParams: Array<string> = [];
const paraYdata: number[][] = [];
// experiment interface search space obj
const searchRange = JSON.parse(searchSpace);
const dimName = Object.keys(searchRange);
// trial-jobs interface list
} else {
Object.keys(dataSource).map(item => {
const temp = dataSource[item];
if (temp.status === 'SUCCEEDED') {
accPara.push(temp.acc.default);
eachTrialParams.push(temp.description.parameters);
// may be a succeed trial hasn't final result
// all detail page may be break down if havn't if
if (temp.acc !== undefined) {
if (temp.acc.default !== undefined) {
accPara.push(temp.acc.default);
}
}
});
if (this._isMounted) {
this.setState({ max: Math.max(...accPara), min: Math.min(...accPara) }, () => {
this.getParallelAxis(dimName, searchRange, accPara, eachTrialParams, paraYdata);
this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams);
});
}
}
}
// get percent value number
percentNum = (value: string) => {
......@@ -229,9 +284,10 @@ class Para extends React.Component<ParaProps, ParaState> {
// deal with response data into pic data
getOption = (dataObj: ParaObj) => {
// dataObj [[y1], [y2]... [default metric]]
const { max, min } = this.state;
let parallelAxis = dataObj.parallelAxis;
let paralleData = dataObj.data;
const parallelAxis = dataObj.parallelAxis;
const paralleData = dataObj.data;
let visualMapObj = {};
if (max === min) {
visualMapObj = {
......@@ -251,7 +307,7 @@ class Para extends React.Component<ParaProps, ParaState> {
color: ['#CA0000', '#FFC400', '#90EE90']
};
}
let optionown = {
const optionown = {
parallelAxis,
tooltip: {
trigger: 'item'
......@@ -288,21 +344,11 @@ class Para extends React.Component<ParaProps, ParaState> {
}
};
// please wait the data
if (this._isMounted) {
if (paralleData.length === 0) {
this.setState({
paraNodata: 'No data'
});
} else {
this.setState({
paraNodata: ''
});
}
}
// draw search space graph
if (this._isMounted) {
this.setState(() => ({
option: optionown
option: optionown,
paraNodata: '',
sutrialCount: paralleData.length
}));
}
}
......@@ -320,6 +366,68 @@ class Para extends React.Component<ParaProps, ParaState> {
this.hyperParaPic(dataSource, expSearchSpace);
}
swapReInit = () => {
const { clickCounts } = this.state;
const val = clickCounts + 1;
if (this._isMounted) {
this.setState({ isLoadConfirm: true, clickCounts: val, });
}
const { paraBack, swapAxisArr } = this.state;
const paralDim = paraBack.parallelAxis;
const paraData = paraBack.data;
let temp: number;
let dim1: number;
let dim2: number;
let bool1: boolean = false;
let bool2: boolean = false;
let bool3: boolean = false;
Object.keys(paralDim).map(item => {
const paral = paralDim[item];
switch (paral.name) {
case swapAxisArr[0]:
dim1 = paral.dim;
bool1 = true;
break;
case swapAxisArr[1]:
dim2 = paral.dim;
bool2 = true;
break;
default:
}
if (bool1 && bool2) {
bool3 = true;
}
});
// swap dim's number
Object.keys(paralDim).map(item => {
if (bool3) {
if (paralDim[item].name === swapAxisArr[0]) {
paralDim[item].dim = dim2;
}
if (paralDim[item].name === swapAxisArr[1]) {
paralDim[item].dim = dim1;
}
}
});
paralDim.sort(this.sortDimY);
// swap data array
Object.keys(paraData).map(paraItem => {
temp = paraData[paraItem][dim1];
paraData[paraItem][dim1] = paraData[paraItem][dim2];
paraData[paraItem][dim2] = temp;
});
this.getOption(paraBack);
// please wait the data
if (this._isMounted) {
this.setState(() => ({
isLoadConfirm: false
}));
}
}
sortDimY = (a: Dimobj, b: Dimobj) => {
return a.dim - b.dim;
}
......@@ -374,11 +482,39 @@ class Para extends React.Component<ParaProps, ParaState> {
});
}
componentDidMount() {
this._isMounted = true;
this.reInit();
}
componentWillReceiveProps(nextProps: ParaProps) {
const dataSource = nextProps.dataSource;
const expSearchSpace = nextProps.expSearchSpace;
const { dataSource, expSearchSpace, whichGraph } = nextProps;
if (whichGraph === '2') {
this.hyperParaPic(dataSource, expSearchSpace);
}
}
shouldComponentUpdate(nextProps: ParaProps, nextState: ParaState) {
const { whichGraph } = nextProps;
const beforeGraph = this.props.whichGraph;
if (whichGraph === '2') {
if (whichGraph !== beforeGraph) {
return true;
}
const { sutrialCount, clickCounts } = nextState;
const beforeCount = this.state.sutrialCount;
const beforeClickCount = this.state.clickCounts;
if (sutrialCount !== beforeCount) {
return true;
}
if (clickCounts !== beforeClickCount) {
return true;
}
}
return false;
}
componentWillUnmount() {
......@@ -386,7 +522,7 @@ class Para extends React.Component<ParaProps, ParaState> {
}
render() {
const { option, paraNodata, dimName } = this.state;
const { option, paraNodata, dimName, isLoadConfirm } = this.state;
return (
<Row className="parameter">
<Row>
......@@ -423,7 +559,8 @@ class Para extends React.Component<ParaProps, ParaState> {
<Button
type="primary"
className="changeBtu tableButton"
onClick={this.reInit}
onClick={this.swapReInit}
disabled={isLoadConfirm}
>
Confirm
</Button>
......@@ -434,7 +571,7 @@ class Para extends React.Component<ParaProps, ParaState> {
<ReactEcharts
option={option}
style={this.chartMulineStyle}
lazyUpdate={true}
// lazyUpdate={true}
notMerge={true} // update now
/>
<div className="noneData">{paraNodata}</div>
......
import * as React from 'react';
import axios from 'axios';
import ReactEcharts from 'echarts-for-react';
import {
Row, Table, Button, Popconfirm, Modal, Checkbox
} from 'antd';
import { Row, Table, Button, Popconfirm, Modal, Checkbox } from 'antd';
const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow';
// import DefaultMetric from '../public-child/DefaultMetrc';
import IntermediateVal from '../public-child/IntermediateVal';
import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import '../../static/style/search.scss';
require('../../static/style/tableStatus.css');
require('../../static/style/logPath.scss');
......@@ -33,6 +30,7 @@ interface TableListProps {
platform: string;
logCollection: boolean;
isMultiPhase: boolean;
isTableLoading: boolean;
}
interface TableListState {
......@@ -197,7 +195,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() {
const { entries, tableSource, updateList } = this.props;
const { entries, tableSource, updateList, isTableLoading } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state;
let showTitle = COLUMN;
let bgColor = '';
......@@ -420,6 +418,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
dataSource={tableSource}
className="commonTableStyle"
pagination={{ pageSize: entries }}
loading={isTableLoading}
/>
{/* Intermediate Result Modal */}
<Modal
......
import axios from 'axios';
import {
message
} from 'antd';
import { message } from 'antd';
import { MANAGER_IP } from './const';
import { FinalResult, FinalType } from './interface';
import { FinalResult, FinalType, TableObj } from './interface';
const convertTime = (num: number) => {
if (num % 3600 === 0) {
......@@ -131,7 +129,16 @@ const killJob = (key: number, id: string, status: string, updateList: Function)
});
};
const filterByStatus = (item: TableObj) => {
return item.status === 'SUCCEEDED';
};
// a waittiong trial may havn't start time
const filterDuration = (item: TableObj) => {
return item.status !== 'WAITING';
};
export {
convertTime, convertDuration, getFinalResult,
getFinal, intermediateGraphOption, killJob
convertTime, convertDuration, getFinalResult, getFinal,
intermediateGraphOption, killJob, filterByStatus, filterDuration
};
......@@ -26,7 +26,7 @@ interface ErrorParameter {
interface Parameters {
parameters: ErrorParameter;
logPath?: string;
intermediate?: Array<number>;
intermediate: Array<number>;
}
interface Experiment {
......
$color: #f2f2f2;
.formatStr{
border: 1px solid #8f8f8f;
color: #333;
padding: 5px 10px;
background-color: #fff;
}
.format {
.ant-modal-header{
background-color: $color;
border-bottom: none;
}
.ant-modal-footer{
background-color: $color;
border-top: none;
}
.ant-modal-body{
background-color: $color;
padding: 10px 24px !important;
}
}
......@@ -52,4 +52,3 @@
.link{
margin-bottom: 10px;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment