Commit 05913424 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' into quniform-tuners

parents e3c8552f 1dab3118
......@@ -193,13 +193,19 @@ class HyperoptTuner(Tuner):
HyperoptTuner is a tuner which using hyperopt algorithm.
"""
def __init__(self, algorithm_name, optimize_mode='minimize'):
def __init__(self, algorithm_name, optimize_mode='minimize',
parallel_optimize=False, constant_liar_type='min'):
"""
Parameters
----------
algorithm_name : str
algorithm_name includes "tpe", "random_search" and anneal".
optimize_mode : str
parallel_optimize : bool
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
constant_liar_type : str
constant_liar_type including "min", "max" and "mean"
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
"""
self.algorithm_name = algorithm_name
self.optimize_mode = OptimizeMode(optimize_mode)
......@@ -208,6 +214,13 @@ class HyperoptTuner(Tuner):
self.rval = None
self.supplement_data_num = 0
self.parallel = parallel_optimize
if self.parallel:
self.CL_rval = None
self.constant_liar_type = constant_liar_type
self.running_data = []
self.optimal_y = None
def _choose_tuner(self, algorithm_name):
"""
Parameters
......@@ -269,6 +282,10 @@ class HyperoptTuner(Tuner):
# but it can cause deplicate parameter rarely
total_params = self.get_suggestion(random_search=True)
self.total_data[parameter_id] = total_params
if self.parallel:
self.running_data.append(parameter_id)
params = split_index(total_params)
return params
......@@ -290,10 +307,39 @@ class HyperoptTuner(Tuner):
raise RuntimeError('Received parameter_id not in total_data.')
params = self.total_data[parameter_id]
# code for parallel
if self.parallel:
constant_liar = kwargs.get('constant_liar', False)
if constant_liar:
rval = self.CL_rval
else:
rval = self.rval
self.running_data.remove(parameter_id)
# update the reward of optimal_y
if self.optimal_y is None:
if self.constant_liar_type == 'mean':
self.optimal_y = [reward, 1]
else:
self.optimal_y = reward
else:
if self.constant_liar_type == 'mean':
_sum = self.optimal_y[0] + reward
_number = self.optimal_y[1] + 1
self.optimal_y = [_sum, _number]
elif self.constant_liar_type == 'min':
self.optimal_y = min(self.optimal_y, reward)
elif self.constant_liar_type == 'max':
self.optimal_y = max(self.optimal_y, reward)
logger.debug("Update optimal_y with reward, optimal_y = %s", self.optimal_y)
else:
rval = self.rval
if self.optimize_mode is OptimizeMode.Maximize:
reward = -reward
rval = self.rval
domain = rval.domain
trials = rval.trials
......@@ -378,13 +424,26 @@ class HyperoptTuner(Tuner):
total_params : dict
parameter suggestion
"""
if self.parallel and len(self.total_data)>20 and len(self.running_data) and self.optimal_y is not None:
self.CL_rval = copy.deepcopy(self.rval)
if self.constant_liar_type == 'mean':
_constant_liar_y = self.optimal_y[0] / self.optimal_y[1]
else:
_constant_liar_y = self.optimal_y
for _parameter_id in self.running_data:
self.receive_trial_result(parameter_id=_parameter_id, parameters=None, value=_constant_liar_y, constant_liar=True)
rval = self.CL_rval
random_state = np.random.randint(2**31 - 1)
else:
rval = self.rval
random_state = rval.rstate.randint(2**31 - 1)
trials = rval.trials
algorithm = rval.algo
new_ids = rval.trials.new_trial_ids(1)
rval.trials.refresh()
random_state = rval.rstate.randint(2**31 - 1)
if random_search:
new_trials = hp.rand.suggest(new_ids, rval.domain, trials,
random_state)
......
import * as React from 'react';
import axios from 'axios';
import { MANAGER_IP } from '../static/const';
import { Row, Col, Tabs, Input, Select, Button, Icon } from 'antd';
import { Row, Col, Tabs, Select, Button, Icon } from 'antd';
const Option = Select.Option;
import { TableObj, Parameters } from '../static/interface';
import { TableObj, Parameters, ExperimentInfo } from '../static/interface';
import { getFinal } from '../static/function';
import DefaultPoint from './trial-detail/DefaultMetricPoint';
import Duration from './trial-detail/Duration';
......@@ -13,6 +13,7 @@ import Intermediate from './trial-detail/Intermeidate';
import TableList from './trial-detail/TableList';
const TabPane = Tabs.TabPane;
import '../static/style/trialsDetail.scss';
import '../static/style/search.scss';
interface TrialDetailState {
accSource: object;
......@@ -20,8 +21,6 @@ interface TrialDetailState {
tableListSource: Array<TableObj>;
searchResultSource: Array<TableObj>;
isHasSearch: boolean;
experimentStatus: string;
experimentPlatform: string;
experimentLogCollection: boolean;
entriesTable: number; // table components val
entriesInSelect: string;
......@@ -31,6 +30,9 @@ interface TrialDetailState {
hyperCounts: number; // user click the hyper-parameter counts
durationCounts: number;
intermediateCounts: number;
experimentInfo: ExperimentInfo;
searchFilter: string;
searchPlaceHolder: string;
}
interface TrialsDetailProps {
......@@ -46,6 +48,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
public interAllTableList = 2;
public tableList: TableList | null;
public searchInput: HTMLInputElement | null;
private titleOfacc = (
<Title1 text="Default metric" icon="3.png" />
......@@ -74,8 +77,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
accNodata: '',
tableListSource: [],
searchResultSource: [],
experimentStatus: '',
experimentPlatform: '',
experimentLogCollection: false,
entriesTable: 20,
entriesInSelect: '20',
......@@ -85,7 +86,13 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
isMultiPhase: false,
hyperCounts: 0,
durationCounts: 0,
intermediateCounts: 0
intermediateCounts: 0,
experimentInfo: {
platform: '',
optimizeMode: 'maximize'
},
searchFilter: 'id',
searchPlaceHolder: 'Search by id'
};
}
......@@ -212,16 +219,34 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}));
}
} else {
const { tableListSource } = this.state;
const { tableListSource, searchFilter } = this.state;
const searchResultList: Array<TableObj> = [];
Object.keys(tableListSource).map(key => {
const item = tableListSource[key];
if (item.sequenceId.toString() === targetValue
|| item.id.includes(targetValue)
|| item.status.toUpperCase().includes(targetValue.toUpperCase())
) {
switch (searchFilter) {
case 'id':
if (item.id.toUpperCase().includes(targetValue.toUpperCase())) {
searchResultList.push(item);
}
break;
case 'Trial No.':
if (item.sequenceId.toString() === targetValue) {
searchResultList.push(item);
}
break;
case 'status':
if (item.status.toUpperCase().includes(targetValue.toUpperCase())) {
searchResultList.push(item);
}
break;
case 'parameters':
const strParameters = JSON.stringify(item.description.parameters, null, 4);
if (strParameters.includes(targetValue)) {
searchResultList.push(item);
}
break;
default:
}
});
if (this._isMounted) {
this.setState(() => ({
......@@ -282,6 +307,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
alert('TableList component was not properly initialized.');
}
getSearchFilter = (value: string) => {
// clear input value and re-render table
if (this.searchInput !== null) {
this.searchInput.value = '';
if (this._isMounted === true) {
this.setState(() => ({ isHasSearch: false }));
}
}
if (this._isMounted === true) {
this.setState(() => ({ searchFilter: value, searchPlaceHolder: `Search by ${value}` }));
}
}
// get and set logCollection val
checkExperimentPlatform = () => {
axios(`${MANAGER_IP}/experiment`, {
......@@ -289,7 +327,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
})
.then(res => {
if (res.status === 200) {
const trainingPlatform = res.data.params.trainingServicePlatform !== undefined
const trainingPlatform: string = res.data.params.trainingServicePlatform !== undefined
?
res.data.params.trainingServicePlatform
:
......@@ -299,12 +337,24 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
let expLogCollection: boolean = false;
const isMultiy: boolean = res.data.params.multiPhase !== undefined
? res.data.params.multiPhase : false;
const tuner = res.data.params.tuner;
// I'll set optimize is maximize if user not set optimize
let optimize: string = 'maximize';
if (tuner !== undefined) {
if (tuner.classArgs !== undefined) {
if (tuner.classArgs.optimize_mode !== undefined) {
if (tuner.classArgs.optimize_mode === 'minimize') {
optimize = 'minimize';
}
}
}
}
if (logCollection !== undefined && logCollection !== 'none') {
expLogCollection = true;
}
if (this._isMounted) {
this.setState({
experimentPlatform: trainingPlatform,
experimentInfo: { platform: trainingPlatform, optimizeMode: optimize },
searchSpace: res.data.params.searchSpace,
experimentLogCollection: expLogCollection,
isMultiPhase: isMultiy
......@@ -343,8 +393,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
const {
tableListSource, searchResultSource, isHasSearch, isMultiPhase,
entriesTable, experimentPlatform, searchSpace, experimentLogCollection,
whichGraph
entriesTable, experimentInfo, searchSpace, experimentLogCollection,
whichGraph, searchPlaceHolder
} = this.state;
const source = isHasSearch ? searchResultSource : tableListSource;
return (
......@@ -354,9 +404,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<TabPane tab={this.titleOfacc} key="1">
<Row className="graph">
<DefaultPoint
height={432}
height={402}
showSource={source}
whichGraph={whichGraph}
optimize={experimentInfo.optimizeMode}
/>
</Row>
</TabPane>
......@@ -408,11 +459,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Compare
</Button>
<Input
<Select defaultValue="id" className="filter" onSelect={this.getSearchFilter}>
<Option value="id">Id</Option>
<Option value="Trial No.">Trial No.</Option>
<Option value="status">Status</Option>
<Option value="parameters">Parameters</Option>
</Select>
<input
type="text"
placeholder="Search by id, trial No. or status"
className="search-input"
placeholder={searchPlaceHolder}
onChange={this.searchTrial}
style={{ width: 230, marginLeft: 6 }}
style={{ width: 230 }}
ref={text => (this.searchInput) = text}
/>
</Col>
</Row>
......@@ -420,7 +479,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
entries={entriesTable}
tableSource={source}
isMultiPhase={isMultiPhase}
platform={experimentPlatform}
platform={experimentInfo.platform}
updateList={this.getDetailSource}
logCollection={experimentLogCollection}
ref={(tabList) => this.tableList = tabList}
......
import * as React from 'react';
import { Switch } from 'antd';
import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface';
......@@ -10,32 +11,36 @@ interface DefaultPointProps {
showSource: Array<TableObj>;
height: number;
whichGraph: string;
optimize: string;
}
interface DefaultPointState {
defaultSource: object;
accNodata: string;
succeedTrials: number;
isViewBestCurve: boolean;
}
class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
public _isMounted = false;
public _isDefaultMounted = false;
constructor(props: DefaultPointProps) {
super(props);
this.state = {
defaultSource: {},
accNodata: '',
succeedTrials: 10000000
succeedTrials: 10000000,
isViewBestCurve: false
};
}
defaultMetric = (succeedSource: Array<TableObj>) => {
defaultMetric = (succeedSource: Array<TableObj>, isCurve: boolean) => {
const { optimize } = this.props;
const accSource: Array<DetailAccurPoint> = [];
const showSource: Array<TableObj> = succeedSource.filter(filterByStatus);
const lengthOfSource = showSource.length;
const tooltipDefault = lengthOfSource === 0 ? 'No data' : '';
if (this._isMounted === true) {
if (this._isDefaultMounted === true) {
this.setState(() => ({
succeedTrials: lengthOfSource,
accNodata: tooltipDefault
......@@ -55,34 +60,125 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
type: 'value',
}
};
if (this._isMounted === true) {
if (this._isDefaultMounted === true) {
this.setState(() => ({
defaultSource: nullGraph
}));
}
} else {
const resultList: Array<number | string>[] = [];
const resultList: Array<number | object>[] = [];
const lineListDefault: Array<number> = [];
Object.keys(showSource).map(item => {
const temp = showSource[item];
if (temp.acc !== undefined) {
if (temp.acc.default !== undefined) {
const searchSpace = temp.description.parameters;
lineListDefault.push(temp.acc.default);
accSource.push({
acc: temp.acc.default,
index: temp.sequenceId,
searchSpace: JSON.stringify(searchSpace)
searchSpace: searchSpace
});
}
}
});
// deal with best metric line
const bestCurve: Array<number | object>[] = []; // best curve data source
bestCurve.push([0, lineListDefault[0], accSource[0].searchSpace]); // push the first value
if (optimize === 'maximize') {
for (let i = 1; i < lineListDefault.length; i++) {
const val = lineListDefault[i];
const latest = bestCurve[bestCurve.length - 1][1];
if (val >= latest) {
bestCurve.push([i, val, accSource[i].searchSpace]);
} else {
bestCurve.push([i, latest, accSource[i].searchSpace]);
}
}
} else {
for (let i = 1; i < lineListDefault.length; i++) {
const val = lineListDefault[i];
const latest = bestCurve[bestCurve.length - 1][1];
if (val <= latest) {
bestCurve.push([i, val, accSource[i].searchSpace]);
} else {
bestCurve.push([i, latest, accSource[i].searchSpace]);
}
}
}
Object.keys(accSource).map(item => {
const items = accSource[item];
let temp: Array<number | string>;
temp = [items.index, items.acc, JSON.parse(items.searchSpace)];
let temp: Array<number | object>;
temp = [items.index, items.acc, items.searchSpace];
resultList.push(temp);
});
// isViewBestCurve: false show default metric graph
// isViewBestCurve: true show best curve
if (isCurve === true) {
if (this._isDefaultMounted === true) {
this.setState(() => ({
defaultSource: this.drawBestcurve(bestCurve, resultList)
}));
}
} else {
if (this._isDefaultMounted === true) {
this.setState(() => ({
defaultSource: this.drawDefaultMetric(resultList)
}));
}
}
}
}
const allAcuracy = {
drawBestcurve = (realDefault: Array<number | object>[], resultList: Array<number | object>[]) => {
return {
grid: {
left: '8%'
},
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForAccuracy) {
if (data.data[0] < realDefault.length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForAccuracy) {
const result = '<div class="tooldetailAccuracy">' +
'<div>Trial No.: ' + data.data[0] + '</div>' +
'<div>Optimization curve: ' + data.data[1] + '</div>' +
'<div>Parameters: ' +
'<pre>' + JSON.stringify(data.data[2], null, 4) + '</pre>' +
'</div>' +
'</div>';
return result;
}
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
scale: true
},
series: [{
symbolSize: 6,
type: 'scatter',
data: resultList
}, {
type: 'line',
lineStyle: { color: '#FF6600' },
data: realDefault
}]
};
}
drawDefaultMetric = (resultList: Array<number | object>[]) => {
return {
grid: {
left: '8%'
},
......@@ -114,6 +210,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
yAxis: {
name: 'Default metric',
type: 'value',
scale: true
},
series: [{
symbolSize: 6,
......@@ -121,11 +218,15 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
data: resultList
}]
};
if (this._isMounted === true) {
this.setState(() => ({
defaultSource: allAcuracy
}));
}
loadDefault = (checked: boolean) => {
// checked: true show best metric curve
const { showSource } = this.props;
if (this._isDefaultMounted === true) {
this.defaultMetric(showSource, checked);
// ** deal with data and then update view layer
this.setState(() => ({ isViewBestCurve: checked }));
}
}
......@@ -133,16 +234,21 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
componentWillReceiveProps(nextProps: DefaultPointProps) {
const { whichGraph, showSource } = nextProps;
const { isViewBestCurve } = this.state;
if (whichGraph === '1') {
this.defaultMetric(showSource);
this.defaultMetric(showSource, isViewBestCurve);
}
}
shouldComponentUpdate(nextProps: DefaultPointProps, nextState: DefaultPointState) {
const { whichGraph } = nextProps;
const succTrial = this.state.succeedTrials;
const { succeedTrials } = nextState;
if (whichGraph === '1') {
const { succeedTrials, isViewBestCurve } = nextState;
const succTrial = this.state.succeedTrials;
const isViewBestCurveBefore = this.state.isViewBestCurve;
if (isViewBestCurveBefore !== isViewBestCurve) {
return true;
}
if (succeedTrials !== succTrial) {
return true;
}
......@@ -152,11 +258,11 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
componentDidMount() {
this._isMounted = true;
this._isDefaultMounted = true;
}
componentWillUnmount() {
this._isMounted = false;
this._isDefaultMounted = false;
}
render() {
......@@ -164,6 +270,12 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
const { defaultSource, accNodata } = this.state;
return (
<div>
<div className="default-metric">
<div className="position">
<span className="bold">optimization curve</span>
<Switch defaultChecked={false} onChange={this.loadDefault} />
</div>
</div>
<ReactEcharts
option={defaultSource}
style={{
......@@ -173,7 +285,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}}
theme="my_theme"
notMerge={true} // update now
// lazyUpdate={true}
/>
<div className="showMess">{accNodata}</div>
</div>
......
......@@ -114,7 +114,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
yAxis: {
type: 'value',
name: 'metric'
name: 'Metric'
},
series: trialIntermediate
};
......@@ -136,7 +136,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
yAxis: {
type: 'value',
name: 'metric'
name: 'Metric'
}
};
if (this._isMounted) {
......@@ -283,9 +283,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
{/* style in para.scss */}
<Row className="meline intermediate">
<Col span={8} />
<Col span={3} style={{ height: 34 }}>
<Col span={3} className="inter-filter-btn">
{/* filter message */}
<span>filter</span>
<span>Filter</span>
<Switch
defaultChecked={false}
onChange={this.switchTurn}
......
......@@ -87,13 +87,10 @@ class Para extends React.Component<ParaProps, ParaState> {
let temp: Array<number> = [];
for (let i = 0; i < dimName.length; i++) {
if ('type' in parallelAxis[i]) {
temp.push(
eachTrialParams[item][dimName[i]].toString()
);
temp.push(eachTrialParams[item][dimName[i]].toString());
} else {
temp.push(
eachTrialParams[item][dimName[i]]
);
// default metric
temp.push(eachTrialParams[item][dimName[i]]);
}
}
paraYdata.push(temp);
......@@ -199,11 +196,18 @@ class Para extends React.Component<ParaProps, ParaState> {
break;
// support log distribute
case 'loguniform':
if (lenOfDataSource > 1) {
parallelAxis.push({
dim: i,
name: dimName[i],
type: 'log',
});
} else {
parallelAxis.push({
dim: i,
name: dimName[i]
});
}
break;
default:
......
......@@ -321,9 +321,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key: 'sequenceId',
width: 120,
className: 'tableHead',
sorter:
(a: TableObj, b: TableObj) =>
(a.sequenceId as number) - (b.sequenceId as number)
sorter: (a: TableObj, b: TableObj) => (a.sequenceId as number) - (b.sequenceId as number)
});
break;
case 'ID':
......
......@@ -59,7 +59,7 @@ interface AccurPoint {
interface DetailAccurPoint {
acc: number;
index: number;
searchSpace: string;
searchSpace: object;
}
interface TooltipForIntermediate {
......@@ -117,8 +117,13 @@ interface Intermedia {
hyperPara: object; // each trial hyperpara value
}
interface ExperimentInfo {
platform: string;
optimizeMode: string;
}
export {
TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob,
DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType,
TooltipForIntermediate, SearchSpace, Intermedia
TooltipForIntermediate, SearchSpace, Intermedia, ExperimentInfo
};
......@@ -36,6 +36,10 @@
.strange{
margin-top: 2px;
}
.inter-filter-btn{
height: 34px;
line-height: 34px;
}
.range{
.heng{
margin-left: 6px;
......
......@@ -11,6 +11,24 @@
color: #0071BC;
border-radius: 0;
}
.filter{
width: 100px;
margin-left: 8px;
.ant-select-selection-selected-value{
font-size: 14px;
}
}
.search-input{
height: 32px;
outline: none;
border: 1px solid #d9d9d9;
border-left: none;
padding-left: 8px;
color: #333;
}
.search-input::placeholder{
color: DarkGray;
}
}
.entry{
width: 120px;
......
......@@ -31,14 +31,12 @@
text-align: center;
color:#212121;
font-size: 14px;
/* background-color: #f2f2f2; */
}
th{
padding: 2px;
background-color:white !important;
font-size: 14px;
color: #808080;
border-bottom: 1px solid #d0d0d0;
text-align: center;
}
......@@ -105,3 +103,9 @@
.ant-table-selection{
display: none;
}
/* fix the border-bottom bug in firefox and edge */
.ant-table-thead > tr > th .ant-table-column-sorters::before{
padding-bottom: 25px;
border-bottom: 1px solid #e8e8e8;
}
\ No newline at end of file
......@@ -70,3 +70,16 @@
.allList{
margin-top: 15px;
}
.default-metric{
width: 90%;
text-align: right;
margin-top: 15px;
.position{
color: #333;
.bold{
font-weight: 600;
margin-right: 10px;
}
}
}
This diff is collapsed.
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import sys
import time
import traceback
from utils import GREEN, RED, CLEAR, setup_experiment
def test_nni_cli():
import nnicli as nc
config_file = 'config_test/examples/mnist.test.yml'
try:
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(6)
print(GREEN + 'Testing nnicli:' + config_file + CLEAR)
nc.start_nni(config_file)
time.sleep(3)
nc.set_endpoint('http://localhost:8080')
print(nc.version())
print(nc.get_job_statistics())
print(nc.get_experiment_status())
nc.list_trial_jobs()
print(GREEN + 'Test nnicli {}: TEST PASS'.format(config_file) + CLEAR)
except Exception as error:
print(RED + 'Test nnicli {}: TEST FAIL'.format(config_file) + CLEAR)
print('%r' % error)
traceback.print_exc()
raise error
finally:
nc.stop_nni()
if __name__ == '__main__':
installed = (sys.argv[-1] != '--preinstall')
setup_experiment(installed)
test_nni_cli()
......@@ -88,6 +88,7 @@ def stop_experiment_test():
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8080'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8888'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8989'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8990'], check=True)
# test cmd 'nnictl stop id`
experiment_id = get_experiment_id(EXPERIMENT_URL)
......@@ -96,6 +97,12 @@ def stop_experiment_test():
snooze()
assert not detect_port(8080), '`nnictl stop %s` failed to stop experiments' % experiment_id
# test cmd `nnictl stop --port`
proc = subprocess.run(['nnictl', 'stop', '--port', '8990'])
assert proc.returncode == 0, '`nnictl stop %s` failed with code %d' % (experiment_id, proc.returncode)
snooze()
assert not detect_port(8990), '`nnictl stop %s` failed to stop experiments' % experiment_id
# test cmd `nnictl stop all`
proc = subprocess.run(['nnictl', 'stop', 'all'])
assert proc.returncode == 0, '`nnictl stop all` failed with code %d' % proc.returncode
......
......@@ -36,3 +36,7 @@ jobs:
cd test
python metrics_test.py
displayName: 'Trial job metrics test'
- script: |
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName: 'nnicli test'
......@@ -37,3 +37,7 @@ jobs:
cd test
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName: 'Trial job metrics test'
- script: |
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName: 'nnicli test'
......@@ -57,8 +57,8 @@ class SearchSpaceGenerator(ast.NodeTransformer):
key = self.module_name + '/' + mutable_block
args[0].s = key
if key not in self.search_space:
self.search_space[key] = dict()
self.search_space[key][mutable_layer] = {
self.search_space[key] = {'_type': 'mutable_layer', '_value': {}}
self.search_space[key]['_value'][mutable_layer] = {
'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
......
......@@ -44,8 +44,9 @@ class AnnotationTestCase(TestCase):
self.assertEqual(search_space, json.load(f))
def test_code_generator(self):
code_dir = expand_annotations('testcase/usercode', '_generated')
code_dir = expand_annotations('testcase/usercode', '_generated', nas_mode='classic_mode')
self.assertEqual(code_dir, '_generated')
self._assert_source_equal('testcase/annotated/nas.py', '_generated/nas.py')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
......
import nni
import time
def add_one(inputs):
return inputs + 1
def add_two(inputs):
return inputs + 2
def add_three(inputs):
return inputs + 3
def add_four(inputs):
return inputs + 4
def main():
images = 5
layer_1_out = nni.mutable_layer('mutable_block_39', 'mutable_layer_0',
{'add_one()': add_one, 'add_two()': add_two, 'add_three()':
add_three, 'add_four()': add_four}, {'add_one()': {}, 'add_two()':
{}, 'add_three()': {}, 'add_four()': {}}, [], {'images': images}, 1,
'classic_mode')
layer_2_out = nni.mutable_layer('mutable_block_39', 'mutable_layer_1',
{'add_one()': add_one, 'add_two()': add_two, 'add_three()':
add_three, 'add_four()': add_four}, {'add_one()': {}, 'add_two()':
{}, 'add_three()': {}, 'add_four()': {}}, [], {'layer_1_out':
layer_1_out}, 1, 'classic_mode')
layer_3_out = nni.mutable_layer('mutable_block_39', 'mutable_layer_2',
{'add_one()': add_one, 'add_two()': add_two, 'add_three()':
add_three, 'add_four()': add_four}, {'add_one()': {}, 'add_two()':
{}, 'add_three()': {}, 'add_four()': {}}, [], {'layer_1_out':
layer_1_out, 'layer_2_out': layer_2_out}, 1, 'classic_mode')
nni.report_intermediate_result(layer_1_out)
time.sleep(2)
nni.report_intermediate_result(layer_2_out)
time.sleep(2)
nni.report_intermediate_result(layer_3_out)
time.sleep(2)
layer_3_out = layer_3_out + 10
nni.report_final_result(layer_3_out)
if __name__ == '__main__':
main()
......@@ -143,5 +143,47 @@
"(2 * 3 + 4)",
"(lambda x: 1 + x)"
]
},
"nas/mutable_block_39": {
"_type": "mutable_layer",
"_value": {
"mutable_layer_0": {
"layer_choice": [
"add_one()",
"add_two()",
"add_three()",
"add_four()"
],
"optional_inputs": [
"images"
],
"optional_input_size": 1
},
"mutable_layer_1": {
"layer_choice": [
"add_one()",
"add_two()",
"add_three()",
"add_four()"
],
"optional_inputs": [
"layer_1_out"
],
"optional_input_size": 1
},
"mutable_layer_2": {
"layer_choice": [
"add_one()",
"add_two()",
"add_three()",
"add_four()"
],
"optional_inputs": [
"layer_1_out",
"layer_2_out"
],
"optional_input_size": 1
}
}
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment