Unverified Commit cbf88f79 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #181 from microsoft/master

merge master
parents 8a9b2cb5 2a0fdd3d
...@@ -27,7 +27,6 @@ interface TrialDetailState { ...@@ -27,7 +27,6 @@ interface TrialDetailState {
entriesInSelect: string; entriesInSelect: string;
searchSpace: string; searchSpace: string;
isMultiPhase: boolean; isMultiPhase: boolean;
isTableLoading: boolean;
whichGraph: string; whichGraph: string;
hyperCounts: number; // user click the hyper-parameter counts hyperCounts: number; // user click the hyper-parameter counts
durationCounts: number; durationCounts: number;
...@@ -79,7 +78,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -79,7 +78,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
whichGraph: '1', whichGraph: '1',
isHasSearch: false, isHasSearch: false,
isMultiPhase: false, isMultiPhase: false,
isTableLoading: false,
hyperCounts: 0, hyperCounts: 0,
durationCounts: 0, durationCounts: 0,
intermediateCounts: 0 intermediateCounts: 0
...@@ -95,9 +93,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -95,9 +93,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
]) ])
.then(axios.spread((res, res1) => { .then(axios.spread((res, res1) => {
if (res.status === 200 && res1.status === 200) { if (res.status === 200 && res1.status === 200) {
if (this._isMounted === true) {
this.setState(() => ({ isTableLoading: true }));
}
const trialJobs = res.data; const trialJobs = res.data;
const metricSource = res1.data; const metricSource = res1.data;
const trialTable: Array<TableObj> = []; const trialTable: Array<TableObj> = [];
...@@ -187,10 +182,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -187,10 +182,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
} }
} }
if (this._isMounted) { if (this._isMounted) {
this.setState(() => ({ this.setState(() => ({ tableListSource: trialTable }));
isTableLoading: false,
tableListSource: trialTable
}));
} }
if (entriesInSelect === 'all' && this._isMounted) { if (entriesInSelect === 'all' && this._isMounted) {
this.setState(() => ({ this.setState(() => ({
...@@ -330,7 +322,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -330,7 +322,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
const { const {
tableListSource, searchResultSource, isHasSearch, isMultiPhase, tableListSource, searchResultSource, isHasSearch, isMultiPhase,
entriesTable, experimentPlatform, searchSpace, experimentLogCollection, entriesTable, experimentPlatform, searchSpace, experimentLogCollection,
whichGraph, isTableLoading whichGraph
} = this.state; } = this.state;
const source = isHasSearch ? searchResultSource : tableListSource; const source = isHasSearch ? searchResultSource : tableListSource;
return ( return (
...@@ -407,7 +399,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { ...@@ -407,7 +399,6 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
<TableList <TableList
entries={entriesTable} entries={entriesTable}
tableSource={source} tableSource={source}
isTableLoading={isTableLoading}
isMultiPhase={isMultiPhase} isMultiPhase={isMultiPhase}
platform={experimentPlatform} platform={experimentPlatform}
updateList={this.getDetailSource} updateList={this.getDetailSource}
......
...@@ -123,7 +123,7 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> { ...@@ -123,7 +123,7 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
<Button <Button
onClick={this.showFormatModal.bind(this, record)} onClick={this.showFormatModal.bind(this, record)}
> >
Copy as python Copy as json
</Button> </Button>
</Row> </Row>
</Row> </Row>
......
...@@ -115,13 +115,13 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState> ...@@ -115,13 +115,13 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
}, },
xAxis: { xAxis: {
type: 'category', type: 'category',
name: 'Scape', name: 'Step',
boundaryGap: false, boundaryGap: false,
data: xAxis data: xAxis
}, },
yAxis: { yAxis: {
type: 'value', type: 'value',
name: 'Intermediate' name: 'metric'
}, },
series: trialIntermediate series: trialIntermediate
}; };
......
...@@ -22,6 +22,7 @@ interface ParaState { ...@@ -22,6 +22,7 @@ interface ParaState {
max: number; // graph color bar limit max: number; // graph color bar limit
min: number; min: number;
sutrialCount: number; // succeed trial numbers for SUC sutrialCount: number; // succeed trial numbers for SUC
succeedRenderCount: number; // all succeed trials number
clickCounts: number; clickCounts: number;
isLoadConfirm: boolean; isLoadConfirm: boolean;
} }
...@@ -68,6 +69,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -68,6 +69,7 @@ class Para extends React.Component<ParaProps, ParaState> {
min: 0, min: 0,
max: 1, max: 1,
sutrialCount: 10000000, sutrialCount: 10000000,
succeedRenderCount: 10000000,
clickCounts: 1, clickCounts: 1,
isLoadConfirm: false isLoadConfirm: false
}; };
...@@ -76,7 +78,8 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -76,7 +78,8 @@ class Para extends React.Component<ParaProps, ParaState> {
getParallelAxis = getParallelAxis =
( (
dimName: Array<string>, parallelAxis: Array<Dimobj>, dimName: Array<string>, parallelAxis: Array<Dimobj>,
accPara: Array<number>, eachTrialParams: Array<string> accPara: Array<number>, eachTrialParams: Array<string>,
lengthofTrials: number
) => { ) => {
// get data for every lines. if dim is choice type, number -> toString() // get data for every lines. if dim is choice type, number -> toString()
const paraYdata: number[][] = []; const paraYdata: number[][] = [];
...@@ -120,7 +123,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -120,7 +123,7 @@ class Para extends React.Component<ParaProps, ParaState> {
if (swapAxisArr.length >= 2) { if (swapAxisArr.length >= 2) {
this.swapGraph(paraData, swapAxisArr); this.swapGraph(paraData, swapAxisArr);
} }
this.getOption(paraData); this.getOption(paraData, lengthofTrials);
if (this._isMounted === true) { if (this._isMounted === true) {
this.setState(() => ({ paraBack: paraData })); this.setState(() => ({ paraBack: paraData }));
} }
...@@ -159,8 +162,8 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -159,8 +162,8 @@ class Para extends React.Component<ParaProps, ParaState> {
parallelAxis.push({ parallelAxis.push({
dim: i, dim: i,
name: dimName[i], name: dimName[i],
max: searchKey._value[0] - 1, min: searchKey._value[0],
min: 0 max: searchKey._value[1],
}); });
break; break;
...@@ -248,7 +251,8 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -248,7 +251,8 @@ class Para extends React.Component<ParaProps, ParaState> {
this.setState({ this.setState({
paraNodata: 'No data', paraNodata: 'No data',
option: optionOfNull, option: optionOfNull,
sutrialCount: 0 sutrialCount: 0,
succeedRenderCount: 0
}); });
} }
} else { } else {
...@@ -265,7 +269,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -265,7 +269,7 @@ class Para extends React.Component<ParaProps, ParaState> {
}); });
if (this._isMounted) { if (this._isMounted) {
this.setState({ max: Math.max(...accPara), min: Math.min(...accPara) }, () => { this.setState({ max: Math.max(...accPara), min: Math.min(...accPara) }, () => {
this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams); this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams, lenOfDataSource);
}); });
} }
} }
...@@ -283,7 +287,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -283,7 +287,7 @@ class Para extends React.Component<ParaProps, ParaState> {
} }
// deal with response data into pic data // deal with response data into pic data
getOption = (dataObj: ParaObj) => { getOption = (dataObj: ParaObj, lengthofTrials: number) => {
// dataObj [[y1], [y2]... [default metric]] // dataObj [[y1], [y2]... [default metric]]
const { max, min } = this.state; const { max, min } = this.state;
const parallelAxis = dataObj.parallelAxis; const parallelAxis = dataObj.parallelAxis;
...@@ -348,6 +352,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -348,6 +352,7 @@ class Para extends React.Component<ParaProps, ParaState> {
this.setState(() => ({ this.setState(() => ({
option: optionown, option: optionown,
paraNodata: '', paraNodata: '',
succeedRenderCount: lengthofTrials,
sutrialCount: paralleData.length sutrialCount: paralleData.length
})); }));
} }
...@@ -367,7 +372,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -367,7 +372,7 @@ class Para extends React.Component<ParaProps, ParaState> {
} }
swapReInit = () => { swapReInit = () => {
const { clickCounts } = this.state; const { clickCounts, succeedRenderCount } = this.state;
const val = clickCounts + 1; const val = clickCounts + 1;
if (this._isMounted) { if (this._isMounted) {
this.setState({ isLoadConfirm: true, clickCounts: val, }); this.setState({ isLoadConfirm: true, clickCounts: val, });
...@@ -419,7 +424,7 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -419,7 +424,7 @@ class Para extends React.Component<ParaProps, ParaState> {
paraData[paraItem][dim1] = paraData[paraItem][dim2]; paraData[paraItem][dim1] = paraData[paraItem][dim2];
paraData[paraItem][dim2] = temp; paraData[paraItem][dim2] = temp;
}); });
this.getOption(paraBack); this.getOption(paraBack, succeedRenderCount);
// please wait the data // please wait the data
if (this._isMounted) { if (this._isMounted) {
this.setState(() => ({ this.setState(() => ({
...@@ -503,12 +508,16 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -503,12 +508,16 @@ class Para extends React.Component<ParaProps, ParaState> {
return true; return true;
} }
const { sutrialCount, clickCounts } = nextState; const { sutrialCount, clickCounts, succeedRenderCount } = nextState;
const beforeCount = this.state.sutrialCount; const beforeCount = this.state.sutrialCount;
const beforeClickCount = this.state.clickCounts; const beforeClickCount = this.state.clickCounts;
const beforeRealRenderCount = this.state.succeedRenderCount;
if (sutrialCount !== beforeCount) { if (sutrialCount !== beforeCount) {
return true; return true;
} }
if (succeedRenderCount !== beforeRealRenderCount) {
return true;
}
if (clickCounts !== beforeClickCount) { if (clickCounts !== beforeClickCount) {
return true; return true;
......
...@@ -30,7 +30,6 @@ interface TableListProps { ...@@ -30,7 +30,6 @@ interface TableListProps {
platform: string; platform: string;
logCollection: boolean; logCollection: boolean;
isMultiPhase: boolean; isMultiPhase: boolean;
isTableLoading: boolean;
} }
interface TableListState { interface TableListState {
...@@ -195,7 +194,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -195,7 +194,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
render() { render() {
const { entries, tableSource, updateList, isTableLoading } = this.props; const { entries, tableSource, updateList } = this.props;
const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state; const { intermediateOption, modalVisible, isShowColumn, columnSelected } = this.state;
let showTitle = COLUMN; let showTitle = COLUMN;
let bgColor = ''; let bgColor = '';
...@@ -264,8 +263,12 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -264,8 +263,12 @@ class TableList extends React.Component<TableListProps, TableListState> {
sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number), sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
render: (text: string, record: TableObj) => { render: (text: string, record: TableObj) => {
let duration; let duration;
if (record.duration !== undefined && record.duration > 0) { if (record.duration !== undefined) {
if (record.duration > 0 && record.duration < 1) {
duration = `${record.duration}s`;
} else {
duration = convertDuration(record.duration); duration = convertDuration(record.duration);
}
} else { } else {
duration = 0; duration = 0;
} }
...@@ -418,7 +421,6 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -418,7 +421,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
dataSource={tableSource} dataSource={tableSource}
className="commonTableStyle" className="commonTableStyle"
pagination={{ pageSize: entries }} pagination={{ pageSize: entries }}
loading={isTableLoading}
/> />
{/* Intermediate Result Modal */} {/* Intermediate Result Modal */}
<Modal <Modal
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 4
trialConcurrency: 2
tuner:
codeDir: ../../../examples/tuners/random_nas_tuner
classFileName: random_nas_tuner.py
className: RandomNASTuner
trial:
codeDir: ../../../examples/trials/mnist-nas
command: python3 mnist.py --batch_num 100
gpuNum: 0
useAnnotation: true
multiPhase: false
multiThread: false
trainingServicePlatform: local
...@@ -55,6 +55,10 @@ jobs: ...@@ -55,6 +55,10 @@ jobs:
python --version python --version
powershell.exe -file install.ps1 powershell.exe -file install.ps1
displayName: 'Install nni toolkit via source code' displayName: 'Install nni toolkit via source code'
- script: |
set PATH=$(ENV_PATH)
python -m pip install scikit-learn==0.21.0 --user
displayName: 'Install dependencies for integration tests'
- script: | - script: |
cd test cd test
set PATH=$(ENV_PATH) set PATH=$(ENV_PATH)
......
...@@ -47,3 +47,9 @@ jobs: ...@@ -47,3 +47,9 @@ jobs:
runOptions: commands runOptions: commands
commands: python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/remote_docker.py --mode stop --name $(Build.BuildId) --os windows commands: python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/remote_docker.py --mode stop --name $(Build.BuildId) --os windows
displayName: 'Stop docker' displayName: 'Stop docker'
- task: SSH@0
inputs:
sshEndpoint: $(end_point)
runOptions: commands
commands: sudo rm -rf /tmp/nnitest/$(Build.BuildId)
displayName: 'Clean the remote files'
...@@ -22,9 +22,11 @@ ...@@ -22,9 +22,11 @@
import os import os
import sys import sys
import shutil import shutil
import json
from . import code_generator from . import code_generator
from . import search_space_generator from . import search_space_generator
from . import specific_code_generator
__all__ = ['generate_search_space', 'expand_annotations'] __all__ = ['generate_search_space', 'expand_annotations']
...@@ -74,7 +76,7 @@ def _generate_file_search_space(path, module): ...@@ -74,7 +76,7 @@ def _generate_file_search_space(path, module):
return search_space return search_space
def expand_annotations(src_dir, dst_dir): def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
"""Expand annotations in user code. """Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not. Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str) src_dir: directory path of user code (str)
...@@ -93,11 +95,23 @@ def expand_annotations(src_dir, dst_dir): ...@@ -93,11 +95,23 @@ def expand_annotations(src_dir, dst_dir):
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1) dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
os.makedirs(dst_subdir, exist_ok=True) os.makedirs(dst_subdir, exist_ok=True)
# generate module name from path
if src_subdir == src_dir:
package = ''
else:
assert src_subdir.startswith(src_dir + slash), src_subdir
prefix_len = len(src_dir) + 1
package = src_subdir[prefix_len:].replace(slash, '.') + '.'
for file_name in files: for file_name in files:
src_path = os.path.join(src_subdir, file_name) src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name) dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'): if file_name.endswith('.py'):
if trial_id == '':
annotated |= _expand_file_annotations(src_path, dst_path) annotated |= _expand_file_annotations(src_path, dst_path)
else:
module = package + file_name[:-3]
annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
else: else:
shutil.copyfile(src_path, dst_path) shutil.copyfile(src_path, dst_path)
...@@ -121,3 +135,22 @@ def _expand_file_annotations(src_path, dst_path): ...@@ -121,3 +135,22 @@ def _expand_file_annotations(src_path, dst_path):
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args)) raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else: else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc)) raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
with open(os.path.expanduser('~/nni/experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))) as fd:
para_cfg = json.load(fd)
annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
...@@ -79,7 +79,7 @@ def parse_annotation_mutable_layers(code, lineno): ...@@ -79,7 +79,7 @@ def parse_annotation_mutable_layers(code, lineno):
fields['optional_inputs'] = True fields['optional_inputs'] = True
elif k.id == 'optional_input_size': elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size' assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num, 'Value of optional_input_size should be a number' assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
optional_input_size = value optional_input_size = value
fields['optional_input_size'] = True fields['optional_input_size'] = True
elif k.id == 'layer_output': elif k.id == 'layer_output':
...@@ -102,13 +102,14 @@ def parse_annotation_mutable_layers(code, lineno): ...@@ -102,13 +102,14 @@ def parse_annotation_mutable_layers(code, lineno):
if fields['fixed_inputs']: if fields['fixed_inputs']:
target_call_args.append(fixed_inputs) target_call_args.append(fixed_inputs)
else: else:
target_call_args.append(ast.NameConstant(value=None)) target_call_args.append(ast.List(elts=[]))
if fields['optional_inputs']: if fields['optional_inputs']:
target_call_args.append(optional_inputs) target_call_args.append(optional_inputs)
assert fields['optional_input_size'], 'optional_input_size must exist when optional_inputs exists' assert fields['optional_input_size'], 'optional_input_size must exist when optional_inputs exists'
target_call_args.append(optional_input_size) target_call_args.append(optional_input_size)
else: else:
target_call_args.append(ast.NameConstant(value=None)) target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[]) target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
node = ast.Assign(targets=[layer_output], value=target_call) node = ast.Assign(targets=[layer_output], value=target_call)
nodes.append(node) nodes.append(node)
......
...@@ -54,12 +54,14 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -54,12 +54,14 @@ class SearchSpaceGenerator(ast.NodeTransformer):
def generate_mutable_layer_search_space(self, args): def generate_mutable_layer_search_space(self, args):
mutable_block = args[0].s mutable_block = args[0].s
mutable_layer = args[1].s mutable_layer = args[1].s
if mutable_block not in self.search_space: key = self.module_name + '/' + mutable_block
self.search_space[mutable_block] = dict() args[0].s = key
self.search_space[mutable_block][mutable_layer] = { if key not in self.search_space:
'layer_choice': [key.s for key in args[2].keys], self.search_space[key] = dict()
'optional_inputs': [key.s for key in args[5].keys], self.search_space[key][mutable_layer] = {
'optional_input_size': args[6].n 'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
} }
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import ast
import astor
from nni_cmd.common_utils import print_warning
# pylint: disable=unidiomatic-typecheck
para_cfg = None
prefix_name = None
def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation mutable_layers contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
call = module.body[0].value
nodes = []
mutable_id = prefix_name + '/mutable_block_' + str(lineno)
mutable_layer_cnt = 0
for arg in call.args:
fields = {'layer_choice': False,
'fixed_inputs': False,
'optional_inputs': False,
'optional_input_size': False,
'layer_output': False}
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
func_call = None
for k, value in zip(arg.keys, arg.values):
if k.id == 'layer_choice':
assert not fields['layer_choice'], 'Duplicated field: layer_choice'
assert type(value) is ast.List, 'Value of layer_choice should be a list'
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
if call_name == para_cfg[mutable_id][mutable_layer_id]['chosen_layer']:
func_call = call
assert not call.args, 'Number of args without keyword should be zero'
break
fields['layer_choice'] = True
elif k.id == 'fixed_inputs':
assert not fields['fixed_inputs'], 'Duplicated field: fixed_inputs'
assert type(value) is ast.List, 'Value of fixed_inputs should be a list'
fixed_inputs = value
fields['fixed_inputs'] = True
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [astor.to_source(var).strip() for var in value.elts]
chosen_inputs = para_cfg[mutable_id][mutable_layer_id]['chosen_inputs']
elts = []
for i in chosen_inputs:
index = var_names.index(i)
elts.append(value.elts[index])
optional_inputs = ast.List(elts=elts)
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
pass
elif k.id == 'layer_output':
assert not fields['layer_output'], 'Duplicated field: layer_output'
assert type(value) is ast.Name, 'Value of layer_output should be ast.Name type'
layer_output = value
fields['layer_output'] = True
else:
raise AssertionError('Unexpected field in mutable layer')
# make call for this mutable layer
assert fields['layer_choice'], 'layer_choice must exist'
assert fields['layer_output'], 'layer_output must exist'
if not fields['fixed_inputs']:
fixed_inputs = ast.List(elts=[])
if not fields['optional_inputs']:
optional_inputs = ast.List(elts=[])
inputs = ast.List(elts=[fixed_inputs, optional_inputs])
func_call.args.append(inputs)
node = ast.Assign(targets=[layer_output], value=func_call)
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
return module.body[0]
def parse_annotation_function(code, func_name):
"""Parse an annotation function.
Return the value of `name` keyword argument and the AST Call node.
func_name: expected function name
"""
expr = parse_annotation(code)
call = expr.value
assert type(call) is ast.Call, 'Annotation is not a function call'
assert type(call.func) is ast.Attribute, 'Unexpected annotation function'
assert type(call.func.value) is ast.Name, 'Invalid annotation function name'
assert call.func.value.id == 'nni', 'Annotation is not a NNI function'
assert call.func.attr == func_name, 'internal error #2'
assert len(call.keywords) == 1, 'Annotation function contains more than one keyword argument'
assert call.keywords[0].arg == 'name', 'Annotation keyword argument is not "name"'
name = call.keywords[0].value
return name, call
def parse_nni_variable(code):
"""Parse `nni.variable` expression.
Return the name argument and AST node of annotated expression.
code: annotation string
"""
name, call = parse_annotation_function(code, 'variable')
assert len(call.args) == 1, 'nni.variable contains more than one arguments'
arg = call.args[0]
assert type(arg) is ast.Call, 'Value of nni.variable is not a function call'
assert type(arg.func) is ast.Attribute, 'nni.variable value is not a NNI function'
assert type(arg.func.value) is ast.Name, 'nni.variable value is not a NNI function'
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'
name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
return name, arg
def parse_nni_function(code):
"""Parse `nni.function_choice` expression.
Return the AST node of annotated expression and a list of dumped function call expressions.
code: annotation string
"""
name, call = parse_annotation_function(code, 'function_choice')
funcs = [ast.dump(func, False) for func in call.args]
convert_args_to_dict(call, with_lambda=True)
name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
return call, funcs
def convert_args_to_dict(call, with_lambda=False):
"""Convert all args to a dict such that every key and value in the dict is the same as the value of the arg.
Return the AST Call node with only one arg that is the dictionary
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
del call.args[:]
call.args.append(ast.Dict(keys=keys, values=values))
return call
def make_lambda(call):
"""Wrap an AST Call node to lambda expression node.
call: ast.Call node
"""
empty_args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[])
return ast.Lambda(args=empty_args, body=call)
def test_variable_equal(node1, node2):
"""Test whether two variables are the same."""
if type(node1) is not type(node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
return True
if isinstance(node1, list):
if len(node1) != len(node2):
return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
return node1 == node2
def replace_variable_node(node, annotation):
"""Replace a node annotated by `nni.variable`.
node: the AST node to replace
annotation: annotation string
"""
assert type(node) is ast.Assign, 'nni.variable is not annotating assignment expression'
assert len(node.targets) == 1, 'Annotated assignment has more than one left-hand value'
name, expr = parse_nni_variable(annotation)
assert test_variable_equal(node.targets[0], name), 'Annotated variable has wrong name'
node.value = expr
return node
def replace_function_node(node, annotation):
"""Replace a node annotated by `nni.function_choice`.
node: the AST node to replace
annotation: annotation string
"""
target, funcs = parse_nni_function(annotation)
FuncReplacer(funcs, target).visit(node)
return node
class FuncReplacer(ast.NodeTransformer):
"""To replace target function call expressions in a node annotated by `nni.function_choice`"""
def __init__(self, funcs, target):
"""Constructor.
funcs: list of dumped function call expressions to replace
target: use this AST node to replace matching expressions
"""
self.funcs = set(funcs)
self.target = target
def visit_Call(self, node): # pylint: disable=invalid-name
if ast.dump(node, False) in self.funcs:
return self.target
return node
class Transformer(ast.NodeTransformer):
"""Transform original code to annotated code"""
def __init__(self):
self.stack = []
self.last_line = 0
self.annotated = False
def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
self.last_line = node.lineno
# do nothing for root
if not self.stack:
return self._visit_children(node)
annotation = self.stack[-1]
# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
if annotation is not None: # this expression is annotated
self.stack[-1] = None # so next expression is not
if annotation.startswith('nni.variable'):
return replace_variable_node(node, annotation)
if annotation.startswith('nni.function_choice'):
return replace_function_node(node, annotation)
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
self.annotated = True
else:
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'):
self.stack[-1] = string[1:] # mark that the next expression is annotated
return None
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
annotation = self.stack.pop()
assert annotation is None, 'Annotation has no target'
return node
def parse(code, para, module):
"""Annotate user code.
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
"""
global para_cfg
global prefix_name
para_cfg = para
prefix_name = module
try:
ast_tree = ast.parse(code)
except Exception:
raise RuntimeError('Bad Python code')
transformer = Transformer()
try:
transformer.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0]))
if not transformer.annotated:
return None
return astor.to_source(ast_tree)
...@@ -83,7 +83,7 @@ def parse_args(): ...@@ -83,7 +83,7 @@ def parse_args():
parser_updater_duration.add_argument('--value', '-v', required=True, help='the unit of time should in {\'s\', \'m\', \'h\', \'d\'}') parser_updater_duration.add_argument('--value', '-v', required=True, help='the unit of time should in {\'s\', \'m\', \'h\', \'d\'}')
parser_updater_duration.set_defaults(func=update_duration) parser_updater_duration.set_defaults(func=update_duration)
parser_updater_trialnum = parser_updater_subparsers.add_parser('trialnum', help='update maxtrialnum') parser_updater_trialnum = parser_updater_subparsers.add_parser('trialnum', help='update maxtrialnum')
parser_updater_trialnum.add_argument('--id', '-i', dest='id', help='the id of experiment') parser_updater_trialnum.add_argument('id', nargs='?', help='the id of experiment')
parser_updater_trialnum.add_argument('--value', '-v', required=True) parser_updater_trialnum.add_argument('--value', '-v', required=True)
parser_updater_trialnum.set_defaults(func=update_trialnum) parser_updater_trialnum.set_defaults(func=update_trialnum)
...@@ -103,6 +103,10 @@ def parse_args(): ...@@ -103,6 +103,10 @@ def parse_args():
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment') parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed') parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill) parser_trial_kill.set_defaults(func=trial_kill)
parser_trial_codegen = parser_trial_subparsers.add_parser('codegen', help='generate trial code for a specific trial')
parser_trial_codegen.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_codegen.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to do code generation')
parser_trial_codegen.set_defaults(func=trial_codegen)
#parse experiment command #parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information') parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
......
...@@ -25,6 +25,7 @@ import json ...@@ -25,6 +25,7 @@ import json
import datetime import datetime
import time import time
from subprocess import call, check_output from subprocess import call, check_output
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
...@@ -264,7 +265,7 @@ def trial_kill(args): ...@@ -264,7 +265,7 @@ def trial_kill(args):
return return
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_delete(trial_job_id_url(rest_port, args.id), REST_TIME_OUT) response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
if response and check_response(response): if response and check_response(response):
print(response.text) print(response.text)
else: else:
...@@ -272,6 +273,17 @@ def trial_kill(args): ...@@ -272,6 +273,17 @@ def trial_kill(args):
else: else:
print_error('Restful server is not running...') print_error('Restful server is not running...')
def trial_codegen(args):
'''Generate code for a specific trial'''
print_warning('Currently, this command is only for nni nas programming interface.')
exp_id = check_experiment_id(args)
nni_config = Config(get_config_filename(args))
if not nni_config.get_config('experimentConfig')['useAnnotation']:
print_error('The experiment is not using annotation')
exit(1)
code_dir = nni_config.get_config('experimentConfig')['trial']['codeDir']
expand_annotations(code_dir, './exp_%s_trial_%s_code'%(exp_id, args.trial_id), exp_id, args.trial_id)
def list_experiment(args): def list_experiment(args):
'''Get experiment information''' '''Get experiment information'''
nni_config = Config(get_config_filename(args)) nni_config = Config(get_config_filename(args))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment