Unverified Commit 1c56fea8 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #21 from microsoft/master

pull code
parents 12410686 97829ccd
import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
// the modal of trial compare
interface CompareProps {
compareRows: Array<TableObj>;
visible: boolean;
cancelFunc: () => void;
}
class Compare extends React.Component<CompareProps, {}> {
public _isCompareMount: boolean;
constructor(props: CompareProps) {
super(props);
}
intermediate = () => {
const { compareRows } = this.props;
const trialIntermediate: Array<Intermedia> = [];
const idsList: Array<string> = [];
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
trialIntermediate.push({
name: temp.id,
data: temp.description.intermediate,
type: 'line',
hyperPara: temp.description.parameters
});
idsList.push(temp.id);
});
// find max intermediate number
trialIntermediate.sort((a, b) => { return (b.data.length - a.data.length); });
const legend: Array<string> = [];
// max length
const length = trialIntermediate[0] !== undefined ? trialIntermediate[0].data.length : 0;
const xAxis: Array<number> = [];
Object.keys(trialIntermediate).map(item => {
const temp = trialIntermediate[item];
legend.push(temp.name);
});
for (let i = 1; i <= length; i++) {
xAxis.push(i);
}
const option = {
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForIntermediate) {
if (data.dataIndex < length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForIntermediate) {
const trialId = data.seriesName;
let obj = {};
const temp = trialIntermediate.find(key => key.name === trialId);
if (temp !== undefined) {
obj = temp.hyperPara;
}
return '<div class="tooldetailAccuracy">' +
'<div>Trial ID: ' + trialId + '</div>' +
'<div>Intermediate: ' + data.data + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(obj, null, 4) + '</pre>' +
'</div>' +
'</div>';
}
},
grid: {
left: '5%',
top: 40,
containLabel: true
},
legend: {
data: idsList
},
xAxis: {
type: 'category',
name: 'Step',
boundaryGap: false,
data: xAxis
},
yAxis: {
type: 'value',
name: 'metric'
},
series: trialIntermediate
};
return (
<ReactEcharts
option={option}
style={{ width: '100%', height: 418, margin: '0 auto' }}
notMerge={true} // update now
/>
);
}
// render table column ---
initColumn = () => {
const { compareRows } = this.props;
const idList: Array<string> = [];
const durationList: Array<number> = [];
const parameterList: Array<object> = [];
let parameterKeys: Array<string> = [];
if (compareRows.length !== 0) {
parameterKeys = Object.keys(compareRows[0].description.parameters);
}
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
idList.push(temp.id);
durationList.push(temp.duration);
parameterList.push(temp.description.parameters);
});
return (
<table className="compare">
<tbody>
<tr>
<td />
{Object.keys(idList).map(key => {
return (
<td className="value" key={key}>{idList[key]}</td>
);
})}
</tr>
<tr>
<td className="column">Default metric</td>
{Object.keys(compareRows).map(index => {
const temp = compareRows[index];
return (
<td className="value" key={index}>
<IntermediateVal record={temp}/>
</td>
);
})}
</tr>
<tr>
<td className="column">duration</td>
{Object.keys(durationList).map(index => {
return (
<td className="value" key={index}>{durationList[index]}</td>
);
})}
</tr>
{
Object.keys(parameterKeys).map(index => {
return (
<tr key={index}>
<td className="column" key={index}>{parameterKeys[index]}</td>
{
Object.keys(parameterList).map(key => {
return (
<td key={key} className="value">
{parameterList[key][parameterKeys[index]]}
</td>
);
})
}
</tr>
);
})
}
</tbody>
</table>
);
}
componentDidMount() {
this._isCompareMount = true;
}
componentWillUnmount() {
this._isCompareMount = false;
}
render() {
const { visible, cancelFunc } = this.props;
return (
<Modal
title="Compare trials"
visible={visible}
onCancel={cancelFunc}
footer={null}
destroyOnClose={true}
maskClosable={false}
width="90%"
>
<Row>{this.intermediate()}</Row>
<Row>{this.initColumn()}</Row>
</Modal>
);
}
}
export default Compare;
...@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
{/* trial table list */} {/* trial table list */}
<Title1 text="Trial jobs" icon="6.png" /> <Title1 text="Trial jobs" icon="6.png" />
<Row className="allList"> <Row className="allList">
<Col span={12}> <Col span={10}>
<span>Show</span> <span>Show</span>
<Select <Select
className="entry" className="entry"
...@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</Select> </Select>
<span>entries</span> <span>entries</span>
</Col> </Col>
<Col span={12} className="right"> <Col span={14} className="right">
<Row> <Button
<Col span={12}> type="primary"
<Button className="tableButton editStyle"
type="primary" onClick={this.tableList ? this.tableList.addColumn : this.test}
className="tableButton editStyle" >
onClick={this.tableList ? this.tableList.addColumn : this.test} Add column
> </Button>
Add column <Button
</Button> type="primary"
</Col> className="tableButton editStyle mediateBtn"
<Col span={12}> // use child-component tableList's function, the function is in child-component.
<Input onClick={this.tableList ? this.tableList.compareBtn : this.test}
type="text" >
placeholder="Search by id, trial No. or status" Compare
onChange={this.searchTrial} </Button>
style={{ width: 230, marginLeft: 6 }} <Input
/> type="text"
</Col> placeholder="Search by id, trial No. or status"
</Row> onChange={this.searchTrial}
style={{ width: 230, marginLeft: 6 }}
/>
</Col> </Col>
</Row> </Row>
<TableList <TableList
......
...@@ -72,8 +72,15 @@ class SuccessTable extends React.Component<SuccessTableProps, {}> { ...@@ -72,8 +72,15 @@ class SuccessTable extends React.Component<SuccessTableProps, {}> {
sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number), sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
render: (text: string, record: TableObj) => { render: (text: string, record: TableObj) => {
let duration; let duration;
if (record.duration) { if (record.duration !== undefined) {
duration = convertDuration(record.duration); // duration is nagative number(-1) & 0-1
if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
duration = `${record.duration}s`;
} else {
duration = convertDuration(record.duration);
}
} else {
duration = 0;
} }
return ( return (
<div className="durationsty"><div>{duration}</div></div> <div className="durationsty"><div>{duration}</div></div>
......
import * as React from 'react'; import * as React from 'react';
import { Row, Col, Button, Switch } from 'antd'; import { Row, Col, Button, Switch } from 'antd';
import { TooltipForIntermediate, TableObj } from '../../static/interface'; import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
require('echarts/lib/component/tooltip'); require('echarts/lib/component/tooltip');
require('echarts/lib/component/title'); require('echarts/lib/component/title');
interface Intermedia {
name: string; // id
type: string;
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
interface IntermediateState { interface IntermediateState {
detailSource: Array<TableObj>; detailSource: Array<TableObj>;
interSource: object; interSource: object;
......
import * as React from 'react'; import * as React from 'react';
import axios from 'axios'; import axios from 'axios';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { Row, Table, Button, Popconfirm, Modal, Checkbox } from 'antd'; import { Row, Table, Button, Popconfirm, Modal, Checkbox, Select } from 'antd';
const Option = Select.Option;
const CheckboxGroup = Checkbox.Group; const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const'; import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob } from '../../static/function'; import { convertDuration, intermediateGraphOption, killJob } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface'; import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow'; import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare';
import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import '../../static/style/search.scss'; import '../../static/style/search.scss';
require('../../static/style/tableStatus.css'); require('../../static/style/tableStatus.css');
...@@ -38,6 +40,12 @@ interface TableListState { ...@@ -38,6 +40,12 @@ interface TableListState {
isObjFinal: boolean; isObjFinal: boolean;
isShowColumn: boolean; isShowColumn: boolean;
columnSelected: Array<string>; // user select columnKeys columnSelected: Array<string>; // user select columnKeys
selectRows: Array<TableObj>;
isShowCompareModal: boolean;
selectedRowKeys: string[] | number[];
intermediateData: Array<object>; // a trial's intermediate results (include dict)
intermediateId: string;
intermediateOtherKeys: Array<string>;
} }
interface ColumnIndex { interface ColumnIndex {
...@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
public _isMounted = false; public _isMounted = false;
public intervalTrialLog = 10; public intervalTrialLog = 10;
public _trialId: string; public _trialId: string;
public tables: Table<TableObj> | null;
constructor(props: TableListProps) { constructor(props: TableListProps) {
super(props); super(props);
...@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
modalVisible: false, modalVisible: false,
isObjFinal: false, isObjFinal: false,
isShowColumn: false, isShowColumn: false,
columnSelected: COLUMN isShowCompareModal: false,
columnSelected: COLUMN,
selectRows: [],
selectedRowKeys: [], // close selected trial message after modal closed
intermediateData: [],
intermediateId: '',
intermediateOtherKeys: []
}; };
} }
...@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
.then(res => { .then(res => {
if (res.status === 200) { if (res.status === 200) {
const intermediateArr: number[] = []; const intermediateArr: number[] = [];
// support intermediate result is dict // support intermediate result is dict because the last intermediate result is
// final result in a succeed trial, it may be a dict.
// get intermediate result dict keys array
let otherkeys: Array<string> = ['default'];
if (res.data.length !== 0) {
otherkeys = Object.keys(JSON.parse(res.data[0].data));
}
// intermediateArr just store default val
Object.keys(res.data).map(item => { Object.keys(res.data).map(item => {
const temp = JSON.parse(res.data[item].data); const temp = JSON.parse(res.data[item].data);
if (typeof temp === 'object') { if (typeof temp === 'object') {
...@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
const intermediate = intermediateGraphOption(intermediateArr, id); const intermediate = intermediateGraphOption(intermediateArr, id);
if (this._isMounted) { if (this._isMounted) {
this.setState(() => ({ this.setState(() => ({
intermediateOption: intermediate intermediateData: res.data, // store origin intermediate data for a trial
intermediateOption: intermediate,
intermediateOtherKeys: otherkeys,
intermediateId: id
})); }));
} }
} }
...@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
} }
} }
selectOtherKeys = (value: string) => {
const isShowDefault: boolean = value === 'default' ? true : false;
const { intermediateData, intermediateId } = this.state;
const intermediateArr: number[] = [];
// just watch default key-val
if (isShowDefault === true) {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
} else {
intermediateArr.push(temp);
}
});
} else {
Object.keys(intermediateData).map(item => {
const temp = JSON.parse(intermediateData[item].data);
if (typeof temp === 'object') {
intermediateArr.push(temp[value]);
}
});
}
const intermediate = intermediateGraphOption(intermediateArr, intermediateId);
// re-render
if (this._isMounted) {
this.setState(() => ({
intermediateOption: intermediate
}));
}
}
hideIntermediateModal = () => { hideIntermediateModal = () => {
if (this._isMounted) { if (this._isMounted) {
this.setState({ this.setState({
...@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
); );
} }
fillSelectedRowsTostate = (selected: number[] | string[], selectedRows: Array<TableObj>) => {
if (this._isMounted === true) {
this.setState(() => ({ selectRows: selectedRows, selectedRowKeys: selected }));
}
}
// open Compare-modal
compareBtn = () => {
const { selectRows } = this.state;
if (selectRows.length === 0) {
alert('Please select datas you want to compare!');
} else {
if (this._isMounted === true) {
this.setState({ isShowCompareModal: true });
}
}
}
// close Compare-modal
hideCompareModal = () => {
// close modal. clear select rows data, clear selected track
if (this._isMounted) {
this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
}
}
componentDidMount() { componentDidMount() {
this._isMounted = true; this._isMounted = true;
} }
...@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() { render() {
const { entries, tableSource, updateList } = this.props; const { entries, tableSource, updateList } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state; const { intermediateOption, modalVisible, isShowColumn, columnSelected,
selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys } = this.state;
const rowSelection = {
selectedRowKeys: selectedRowKeys,
onChange: (selected: string[] | number[], selectedRows: Array<TableObj>) => {
this.fillSelectedRowsTostate(selected, selectedRows);
}
};
let showTitle = COLUMN; let showTitle = COLUMN;
let bgColor = ''; let bgColor = '';
const trialJob: Array<TrialJob> = []; const trialJob: Array<TrialJob> = [];
...@@ -264,7 +353,8 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -264,7 +353,8 @@ class TableList extends React.Component<TableListProps, TableListState> {
render: (text: string, record: TableObj) => { render: (text: string, record: TableObj) => {
let duration; let duration;
if (record.duration !== undefined) { if (record.duration !== undefined) {
if (record.duration > 0 && record.duration < 1) { // duration is nagative number(-1) & 0-1
if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
duration = `${record.duration}s`; duration = `${record.duration}s`;
} else { } else {
duration = convertDuration(record.duration); duration = convertDuration(record.duration);
...@@ -416,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -416,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
<Row className="tableList"> <Row className="tableList">
<div id="tableList"> <div id="tableList">
<Table <Table
ref={(table: Table<TableObj> | null) => this.tables = table}
columns={showColumn} columns={showColumn}
rowSelection={rowSelection}
expandedRowRender={this.openRow} expandedRowRender={this.openRow}
dataSource={tableSource} dataSource={tableSource}
className="commonTableStyle" className="commonTableStyle"
...@@ -431,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -431,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
destroyOnClose={true} destroyOnClose={true}
width="80%" width="80%"
> >
{
intermediateOtherKeys.length > 1
?
<Row className="selectKeys">
<Select
className="select"
defaultValue="default"
onSelect={this.selectOtherKeys}
>
{
Object.keys(intermediateOtherKeys).map(item => {
const keys = intermediateOtherKeys[item];
return <Option value={keys} key={item}>{keys}</Option>;
})
}
</Select>
</Row>
:
<div />
}
<ReactEcharts <ReactEcharts
option={intermediateOption} option={intermediateOption}
style={{ style={{
...@@ -457,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -457,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
className="titleColumn" className="titleColumn"
/> />
</Modal> </Modal>
<Compare compareRows={selectRows} visible={isShowCompareModal} cancelFunc={this.hideCompareModal} />
</Row> </Row>
); );
} }
......
...@@ -108,10 +108,15 @@ interface FinalResult { ...@@ -108,10 +108,15 @@ interface FinalResult {
data: string; data: string;
} }
interface Intermedia {
name: string; // id
type: string;
data: Array<number | object>; // intermediate data
hyperPara: object; // each trial hyperpara value
}
export { export {
TableObj, Parameters, Experiment, TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob,
AccurPoint, TrialNumber, TrialJob, DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType,
DetailAccurPoint, TooltipForAccuracy, TooltipForIntermediate, SearchSpace, Intermedia
ParaObj, Dimobj, FinalResult, FinalType,
TooltipForIntermediate, SearchSpace
}; };
.compare{
width: 92%;
margin: 0 auto;
color: #333;
tr{
line-height: 30px;
border-bottom: 1px solid #ccc;
}
.column{
width: 124px;
padding-left: 18px;
font-weight: 700;
}
.value{
width: 152px;
padding-right: 18px;
text-align: right;
}
}
/* some buttons about trial-detail table */ /* some buttons in trial-detail table */
.allList{ .allList{
width: 96%; width: 96%;
margin: 0 auto; margin: 0 auto;
...@@ -31,4 +31,17 @@ ...@@ -31,4 +31,17 @@
} }
} }
Button.mediateBtn{
margin: 0 32px;
}
/* each row's Intermediate btn -> Modal*/
.selectKeys{
/* intermediate result is dict, select box for keys */
.select{
width: 120px;
float: right;
margin-right: 10%;
}
}
...@@ -102,3 +102,8 @@ ...@@ -102,3 +102,8 @@
.ant-modal-title{ .ant-modal-title{
font-size: 20px; font-size: 20px;
} }
/*disable select all checkbox in detail page*/
.ant-table-selection{
display: none;
}
...@@ -22,7 +22,7 @@ class SimpleTuner(Tuner): ...@@ -22,7 +22,7 @@ class SimpleTuner(Tuner):
self.sig_event = Event() self.sig_event = Event()
self.thread_lock = Lock() self.thread_lock = Lock()
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
if self.f_id is None: if self.f_id is None:
self.thread_lock.acquire() self.thread_lock.acquire()
self.f_id = parameter_id self.f_id = parameter_id
...@@ -50,7 +50,7 @@ class SimpleTuner(Tuner): ...@@ -50,7 +50,7 @@ class SimpleTuner(Tuner):
self.thread_lock.release() self.thread_lock.release()
return self.trial_meta[parameter_id] return self.trial_meta[parameter_id]
def receive_trial_result(self, parameter_id, parameters, reward): def receive_trial_result(self, parameter_id, parameters, reward, **kwargs):
self.thread_lock.acquire() self.thread_lock.acquire()
if parameter_id == self.f_id: if parameter_id == self.f_id:
self.trial_meta[parameter_id]['checksum'] = reward['checksum'] self.trial_meta[parameter_id]['checksum'] = reward['checksum']
......
...@@ -6,9 +6,9 @@ trialConcurrency: 4 ...@@ -6,9 +6,9 @@ trialConcurrency: 4
searchSpacePath: ./search_space.json searchSpacePath: ./search_space.json
tuner: tuner:
codeDir: ../../../src/sdk/pynni/tests builtinTunerName: TPE
classFileName: test_multi_phase_tuner.py classArgs:
className: NaiveMultiPhaseTuner optimize_mode: maximize
trial: trial:
codeDir: . codeDir: .
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: BatchTuner
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: Evolution
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: GridSearch
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: MetisTuner
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 8
trialConcurrency: 4
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: TPE
classArgs:
optimize_mode: maximize
trial:
codeDir: .
command: python3 multi_phase.py
gpuNum: 0
useAnnotation: false
multiPhase: true
multiThread: false
trainingServicePlatform: local
...@@ -6,7 +6,7 @@ class MultiThreadTuner(Tuner): ...@@ -6,7 +6,7 @@ class MultiThreadTuner(Tuner):
def __init__(self): def __init__(self):
self.parent_done = False self.parent_done = False
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
if parameter_id == 0: if parameter_id == 0:
return {'x': 0} return {'x': 0}
else: else:
...@@ -14,7 +14,7 @@ class MultiThreadTuner(Tuner): ...@@ -14,7 +14,7 @@ class MultiThreadTuner(Tuner):
time.sleep(2) time.sleep(2)
return {'x': 1} return {'x': 1}
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
if parameter_id == 0: if parameter_id == 0:
self.parent_done = True self.parent_done = True
......
...@@ -16,12 +16,12 @@ class NaiveTuner(Tuner): ...@@ -16,12 +16,12 @@ class NaiveTuner(Tuner):
self.cur = 0 self.cur = 0
_logger.info('init') _logger.info('init')
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
self.cur += 1 self.cur += 1
_logger.info('generate parameters: %s' % self.cur) _logger.info('generate parameters: %s' % self.cur)
return { 'x': self.cur } return { 'x': self.cur }
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
reward = extract_scalar_reward(value) reward = extract_scalar_reward(value)
_logger.info('receive trial result: %s, %s, %s' % (parameter_id, parameters, reward)) _logger.info('receive trial result: %s, %s, %s' % (parameter_id, parameters, reward))
_result.write('%d %d\n' % (parameters['x'], reward)) _result.write('%d %d\n' % (parameters['x'], reward))
......
...@@ -56,6 +56,7 @@ common_schema = { ...@@ -56,6 +56,7 @@ common_schema = {
Optional('nniManagerIp'): setType('nniManagerIp', str), Optional('nniManagerIp'): setType('nniManagerIp', str),
Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'), Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
Optional('debug'): setType('debug', bool), Optional('debug'): setType('debug', bool),
Optional('versionCheck'): setType('versionCheck', bool),
Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'), Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): setChoice('logCollection', 'http', 'none'), Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
'useAnnotation': setType('useAnnotation', bool), 'useAnnotation': setType('useAnnotation', bool),
......
...@@ -303,6 +303,9 @@ def set_experiment(experiment_config, mode, port, config_file_name): ...@@ -303,6 +303,9 @@ def set_experiment(experiment_config, mode, port, config_file_name):
#debug mode should disable version check #debug mode should disable version check
if experiment_config.get('debug') is not None: if experiment_config.get('debug') is not None:
request_data['versionCheck'] = not experiment_config.get('debug') request_data['versionCheck'] = not experiment_config.get('debug')
#validate version check
if experiment_config.get('versionCheck') is not None:
request_data['versionCheck'] = experiment_config.get('versionCheck')
if experiment_config.get('logCollection'): if experiment_config.get('logCollection'):
request_data['logCollection'] = experiment_config.get('logCollection') request_data['logCollection'] = experiment_config.get('logCollection')
...@@ -363,7 +366,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -363,7 +366,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
exit(1) exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
if log_level not in ['trace', 'debug'] and args.debug: if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug' log_level = 'debug'
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level) rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment