Unverified Commit 1c56fea8 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #21 from microsoft/master

pull code
parents 12410686 97829ccd
import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
// the modal of trial compare
interface CompareProps {
compareRows: Array<TableObj>;
visible: boolean;
cancelFunc: () => void;
}
class Compare extends React.Component<CompareProps, {}> {
public _isCompareMount: boolean;
constructor(props: CompareProps) {
super(props);
}
intermediate = () => {
const { compareRows } = this.props;
const trialIntermediate: Array<Intermedia> = [];
const idsList: Array<string> = [];
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
trialIntermediate.push({
name: temp.id,
data: temp.description.intermediate,
type: 'line',
hyperPara: temp.description.parameters
});
idsList.push(temp.id);
});
// find max intermediate number
trialIntermediate.sort((a, b) => { return (b.data.length - a.data.length); });
const legend: Array<string> = [];
// max length
const length = trialIntermediate[0] !== undefined ? trialIntermediate[0].data.length : 0;
const xAxis: Array<number> = [];
Object.keys(trialIntermediate).map(item => {
const temp = trialIntermediate[item];
legend.push(temp.name);
});
for (let i = 1; i <= length; i++) {
xAxis.push(i);
}
const option = {
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForIntermediate) {
if (data.dataIndex < length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForIntermediate) {
const trialId = data.seriesName;
let obj = {};
const temp = trialIntermediate.find(key => key.name === trialId);
if (temp !== undefined) {
obj = temp.hyperPara;
}
return '<div class="tooldetailAccuracy">' +
'<div>Trial ID: ' + trialId + '</div>' +
'<div>Intermediate: ' + data.data + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(obj, null, 4) + '</pre>' +
'</div>' +
'</div>';
}
},
grid: {
left: '5%',
top: 40,
containLabel: true
},
legend: {
data: idsList
},
xAxis: {
type: 'category',
name: 'Step',
boundaryGap: false,
data: xAxis
},
yAxis: {
type: 'value',
name: 'metric'
},
series: trialIntermediate
};
return (
<ReactEcharts
option={option}
style={{ width: '100%', height: 418, margin: '0 auto' }}
notMerge={true} // update now
/>
);
}
// render table column ---
initColumn = () => {
const { compareRows } = this.props;
const idList: Array<string> = [];
const durationList: Array<number> = [];
const parameterList: Array<object> = [];
let parameterKeys: Array<string> = [];
if (compareRows.length !== 0) {
parameterKeys = Object.keys(compareRows[0].description.parameters);
}
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
idList.push(temp.id);
durationList.push(temp.duration);
parameterList.push(temp.description.parameters);
});
return (
<table className="compare">
<tbody>
<tr>
<td />
{Object.keys(idList).map(key => {
return (
<td className="value" key={key}>{idList[key]}</td>
);
})}
</tr>
<tr>
<td className="column">Default metric</td>
{Object.keys(compareRows).map(index => {
const temp = compareRows[index];
return (
<td className="value" key={index}>
<IntermediateVal record={temp}/>
</td>
);
})}
</tr>
<tr>
<td className="column">duration</td>
{Object.keys(durationList).map(index => {
return (
<td className="value" key={index}>{durationList[index]}</td>
);
})}
</tr>
{
Object.keys(parameterKeys).map(index => {
return (
<tr key={index}>
<td className="column" key={index}>{parameterKeys[index]}</td>
{
Object.keys(parameterList).map(key => {
return (
<td key={key} className="value">
{parameterList[key][parameterKeys[index]]}
</td>
);
})
}
</tr>
);
})
}
</tbody>
</table>
);
}
componentDidMount() {
this._isCompareMount = true;
}
componentWillUnmount() {
this._isCompareMount = false;
}
render() {
const { visible, cancelFunc } = this.props;
return (
<Modal
title="Compare trials"
visible={visible}
onCancel={cancelFunc}
footer={null}
destroyOnClose={true}
maskClosable={false}
width="90%"
>
<Row>{this.intermediate()}</Row>
<Row>{this.initColumn()}</Row>
</Modal>
);
}
}
export default Compare;
......@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
{/* trial table list */}
<Title1 text="Trial jobs" icon="6.png" />
<Row className="allList">
<Col span={12}>
<Col span={10}>
<span>Show</span>
<Select
className="entry"
......@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</Select>
<span>entries</span>
</Col>
<Col span={12} className="right">
<Row>
<Col span={12}>
<Button
type="primary"
className="tableButton editStyle"
onClick={this.tableList ? this.tableList.addColumn : this.test}
>
Add column
</Button>
</Col>
<Col span={12}>
<Input
type="text"
placeholder="Search by id, trial No. or status"
onChange={this.searchTrial}
style={{ width: 230, marginLeft: 6 }}
/>
</Col>
</Row>
<Col span={14} className="right">
<Button
type="primary"
className="tableButton editStyle"
onClick={this.tableList ? this.tableList.addColumn : this.test}
>
Add column
</Button>
<Button
type="primary"
className="tableButton editStyle mediateBtn"
// use child-component tableList's function, the function is in child-component.
onClick={this.tableList ? this.tableList.compareBtn : this.test}
>
Compare
</Button>
<Input
type="text"
placeholder="Search by id, trial No. or status"
onChange={this.searchTrial}
style={{ width: 230, marginLeft: 6 }}
/>
</Col>
</Row>
<TableList
......
......@@ -72,8 +72,15 @@ class SuccessTable extends React.Component<SuccessTableProps, {}> {
sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
render: (text: string, record: TableObj) => {
let duration;
if (record.duration) {
duration = convertDuration(record.duration);
if (record.duration !== undefined) {
// duration is nagative number(-1) & 0-1
if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
duration = `${record.duration}s`;
} else {
duration = convertDuration(record.duration);
}
} else {
duration = 0;
}
return (
<div className="durationsty"><div>{duration}</div></div>
......
import * as React from 'react';
import { Row, Col, Button, Switch } from 'antd';
import { TooltipForIntermediate, TableObj } from '../../static/interface';
import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
import ReactEcharts from 'echarts-for-react';
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
interface Intermedia {
name: string; // id
type: string;
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
interface IntermediateState {
detailSource: Array<TableObj>;
interSource: object;
......
import * as React from 'react';
import axios from 'axios';
import ReactEcharts from 'echarts-for-react';
import { Row, Table, Button, Popconfirm, Modal, Checkbox } from 'antd';
import { Row, Table, Button, Popconfirm, Modal, Checkbox, Select } from 'antd';
const Option = Select.Option;
const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare';
import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import '../../static/style/search.scss';
require('../../static/style/tableStatus.css');
......@@ -38,6 +40,12 @@ interface TableListState {
isObjFinal: boolean;
isShowColumn: boolean;
columnSelected: Array<string>; // user select columnKeys
selectRows: Array<TableObj>;
isShowCompareModal: boolean;
selectedRowKeys: string[] | number[];
intermediateData: Array<object>; // a trial's intermediate results (include dict)
intermediateId: string;
intermediateOtherKeys: Array<string>;
}
interface ColumnIndex {
......@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
public _isMounted = false;
public intervalTrialLog = 10;
public _trialId: string;
public tables: Table<TableObj> | null;
constructor(props: TableListProps) {
super(props);
......@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
modalVisible: false,
isObjFinal: false,
isShowColumn: false,
columnSelected: COLUMN
isShowCompareModal: false,
columnSelected: COLUMN,
selectRows: [],
selectedRowKeys: [], // close selected trial message after modal closed
intermediateData: [],
intermediateId: '',
intermediateOtherKeys: []
};
}
......@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
.then(res => {
if (res.status === 200) {
const intermediateArr: number[] = [];
// support intermediate result is dict
// support intermediate result is dict because the last intermediate result is
// final result in a succeed trial, it may be a dict.
// get intermediate result dict keys array
let otherkeys: Array<string> = ['default'];
if (res.data.length !== 0) {
otherkeys = Object.keys(JSON.parse(res.data[0].data));
}
// intermediateArr just store default val
Object.keys(res.data).map(item => {
const temp = JSON.parse(res.data[item].data);
if (typeof temp === 'object') {
......@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
const intermediate = intermediateGraphOption(intermediateArr, id);
if (this._isMounted) {
this.setState(() => ({
intermediateOption: intermediate
intermediateData: res.data, // store origin intermediate data for a trial
intermediateOption: intermediate,
intermediateOtherKeys: otherkeys,
intermediateId: id
}));
}
}
......@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
}
}
selectOtherKeys = (value: string) => {
const isShowDefault: boolean = value === 'default' ? true : false;
const { intermediateData, intermediateId } = this.state;
const intermediateArr: number[] = [];
// just watch default key-val
if (isShowDefault === true) {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
} else {
intermediateArr.push(temp);
}
});
} else {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
}
});
}
const intermediate = intermediateGraphOption(intermediateArr, intermediateId);
// re-render
if (this._isMounted) {
this.setState(() => ({
intermediateOption: intermediate
}));
}
}
hideIntermediateModal = () => {
if (this._isMounted) {
this.setState({
......@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
);
}
fillSelectedRowsTostate = (selected: number[] | string[], selectedRows: Array<TableObj>) => {
if (this._isMounted === true) {
this.setState(() => ({ selectRows: selectedRows, selectedRowKeys: selected }));
}
}
// open Compare-modal
compareBtn = () => {
const { selectRows } = this.state;
if (selectRows.length === 0) {
alert('Please select datas you want to compare!');
} else {
if (this._isMounted === true) {
this.setState({ isShowCompareModal: true });
}
}
}
// close Compare-modal
hideCompareModal = () => {
// close modal. clear select rows data, clear selected track
if (this._isMounted) {
this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
}
}
componentDidMount() {
this._isMounted = true;
}
......@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() {
const { entries, tableSource, updateList } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state;
const { intermediateOption, modalVisible, isShowColumn, columnSelected,
selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys } = this.state;
const rowSelection = {
selectedRowKeys: selectedRowKeys,
onChange: (selected: string[] | number[], selectedRows: Array<TableObj>) => {
this.fillSelectedRowsTostate(selected, selectedRows);
}
};
let showTitle = COLUMN;
let bgColor = '';
const trialJob: Array<TrialJob> = [];
......@@ -264,7 +353,8 @@ class TableList extends React.Component<TableListProps, TableListState> {
render: (text: string, record: TableObj) => {
let duration;
if (record.duration !== undefined) {
if (record.duration > 0 && record.duration < 1) {
// duration is nagative number(-1) & 0-1
if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
duration = `${record.duration}s`;
} else {
duration = convertDuration(record.duration);
......@@ -416,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
<Row className="tableList">
<div id="tableList">
<Table
ref={(table: Table<TableObj> | null) => this.tables = table}
columns={showColumn}
rowSelection={rowSelection}
expandedRowRender={this.openRow}
dataSource={tableSource}
className="commonTableStyle"
......@@ -431,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
destroyOnClose={true}
width="80%"
>
{
intermediateOtherKeys.length > 1
?
<Row className="selectKeys">
<Select
className="select"
defaultValue="default"
onSelect={this.selectOtherKeys}
>
{
Object.keys(intermediateOtherKeys).map(item => {
const keys = intermediateOtherKeys[item];
return <Option value={keys} key={item}>{keys}</Option>;
})
}
</Select>
</Row>
:
<div />
}
<ReactEcharts
option={intermediateOption}
style={{
......@@ -457,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
className="titleColumn"
/>
</Modal>
<Compare compareRows={selectRows} visible={isShowCompareModal} cancelFunc={this.hideCompareModal} />
</Row>
);
}
......
......@@ -108,10 +108,15 @@ interface FinalResult {
data: string;
}
interface Intermedia {
name: string; // id
type: string;
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
export {
TableObj, Parameters, Experiment,
AccurPoint, TrialNumber, TrialJob,
DetailAccurPoint, TooltipForAccuracy,
ParaObj, Dimobj, FinalResult, FinalType,
TooltipForIntermediate, SearchSpace
TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob,
DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType,
TooltipForIntermediate, SearchSpace, Intermedia
};
.compare{
width: 92%;
margin: 0 auto;
color: #333;
tr{
line-height: 30px;
border-bottom: 1px solid #ccc;
}
.column{
width: 124px;
padding-left: 18px;
font-weight: 700;
}
.value{
width: 152px;
padding-right: 18px;
text-align: right;
}
}
/* some buttons about trial-detail table */
/* some buttons in trial-detail table */
.allList{
width: 96%;
margin: 0 auto;
......@@ -31,4 +31,17 @@
}
}
Button.mediateBtn{
margin: 0 32px;
}
/* each row's Intermediate btn -> Modal*/
.selectKeys{
/* intermediate result is dict, select box for keys */
.select{
width: 120px;
float: right;
margin-right: 10%;
}
}
......@@ -102,3 +102,8 @@
.ant-modal-title{
font-size: 20px;
}
/*disable select all checkbox in detail page*/
.ant-table-selection{
display: none;
}
......@@ -22,7 +22,7 @@ class SimpleTuner(Tuner):
self.sig_event = Event()
self.thread_lock = Lock()
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
if self.f_id is None:
self.thread_lock.acquire()
self.f_id = parameter_id
......@@ -50,7 +50,7 @@ class SimpleTuner(Tuner):
self.thread_lock.release()
return self.trial_meta[parameter_id]
def receive_trial_result(self, parameter_id, parameters, reward):
def receive_trial_result(self, parameter_id, parameters, reward, **kwargs):
self.thread_lock.acquire()
if parameter_id == self.f_id:
self.trial_meta[parameter_id]['checksum'] = reward['checksum']
......
......@@ -6,9 +6,9 @@ trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
codeDir: ../../../src/sdk/pynni/tests
classFileName: test_multi_phase_tuner.py
className: NaiveMultiPhaseTuner
builtinTunerName: TPE
classArgs:
optimize_mode: maximize
trial:
codeDir: .
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: BatchTuner
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: Evolution
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: GridSearch
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: MetisTuner
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: TPE
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
......@@ -6,7 +6,7 @@ class MultiThreadTuner(Tuner):
def __init__(self):
self.parent_done = False
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
if parameter_id == 0:
return {'x': 0}
else:
......@@ -14,7 +14,7 @@ class MultiThreadTuner(Tuner):
time.sleep(2)
return {'x': 1}
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
if parameter_id == 0:
self.parent_done = True
......
......@@ -16,12 +16,12 @@ class NaiveTuner(Tuner):
self.cur = 0
_logger.info('init')
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
self.cur += 1
_logger.info('generate parameters: %s' % self.cur)
return { 'x': self.cur }
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
reward = extract_scalar_reward(value)
_logger.info('receive trial result: %s, %s, %s' % (parameter_id, parameters, reward))
_result.write('%d %d\n' % (parameters['x'], reward))
......
......@@ -56,6 +56,7 @@ common_schema = {
Optional('nniManagerIp'): setType('nniManagerIp', str),
Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
Optional('debug'): setType('debug', bool),
Optional('versionCheck'): setType('versionCheck', bool),
Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
'useAnnotation': setType('useAnnotation', bool),
......
......@@ -303,6 +303,9 @@ def set_experiment(experiment_config, mode, port, config_file_name):
#debug mode should disable version check
if experiment_config.get('debug') is not None:
request_data['versionCheck'] = not experiment_config.get('debug')
#validate version check
if experiment_config.get('versionCheck') is not None:
request_data['versionCheck'] = experiment_config.get('versionCheck')
if experiment_config.get('logCollection'):
request_data['logCollection'] = experiment_config.get('logCollection')
......@@ -363,7 +366,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
if log_level not in ['trace', 'debug'] and args.debug:
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug'
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment