Commit cae70729 authored by QuanluZhang, committed by xuehui

a simple debug tool for general nas programming interface (#1147)



* Dev nas interface -- document (#1049)

* nas interface doc

* Dev nas compile -- code generator (#1067)

* finish code for parsing mutable_layers annotation and testcode

* Dev nas interface -- update figures (#1070)

 update figs

* update searchspace_generator (#1071)

* GeneralNasInterfaces.md: Fix a typo (#1079)
Signed-off-by: Ce Gao <gaoce@caicloud.io>

* add NAS example and fix bugs (#1083)

update searchspace_generator, add example, update NAS example

* fix bugs (#1108)

* nas example

* fix bugs

* remove

* update

* debug

* fix bug

* remove previous mnist.py

* rename

* code gen for specific trial

* fix conflict

* remove print

* add print warning

* update doc

* update doc

* update doc

* remove comment

* update doc

* remove unnecessary global
parent 4465ad8c
# General Programming Interface for Neural Architecture Search (experimental feature)
_*This is an experimental feature. Currently, we have only implemented the general NAS programming interface; weight sharing and one-shot NAS based on this programming interface will be supported in the following releases.*_
Automatic neural architecture search is taking an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has found models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5], and new innovations keep emerging. However, it takes great effort to implement these algorithms, and it is hard to reuse the code base of one algorithm to implement another.
...@@ -24,6 +26,8 @@ When designing the following model there might be several choices in the fourth
There are two ways to write annotation for this example. For the upper one, `input` of the function calls is `[[],[out3]]`. For the bottom one, `input` is `[[out3],[]]`.
__Debugging__: We provide an `nnictl trial codegen` command to help debug your NAS code on NNI. If trial `XXX` in experiment `YYY` fails, you can run `nnictl trial codegen YYY --trial_id XXX` to generate executable code for that trial under the current directory. With this code, you can run the trial command directly, without NNI, to find out why the trial failed. Essentially, the command compiles your trial code, replacing the NNI NAS annotations with the actually chosen layers and inputs.
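For instance, a debugging session could look like the sketch below. Here `YYY` and `XXX` stand in for your real experiment id and trial id, and the final command depends on the trial command configured in your own experiment; the output folder name follows the implementation added in this commit.

```bash
# Generate the compiled code of failed trial XXX in experiment YYY.
nnictl trial codegen YYY --trial_id XXX

# The code is written under the current directory; in this implementation the
# folder is named exp_<experiment_id>_trial_<trial_id>_code.
cd exp_YYY_trial_XXX_code

# Run the trial command from your experiment config directly, without NNI,
# to reproduce and inspect the failure (the exact command is your own, e.g.):
python3 mnist.py
```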
### Example: choose input connections for a layer
Designing the connections between layers is critical for building a high-performance model. With the provided interface, users can annotate which connections a layer takes as inputs, choosing several from a set of candidate connections. Below is an example which chooses two inputs from three candidate inputs for `concat`. Here `concat` always takes the output of its previous layer via `fixed_inputs`.
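For reference, a minimal sketch of such an annotation is shown here, based on the fields accepted by the annotation parser in this commit (`layer_choice`, `fixed_inputs`, `optional_inputs`, `optional_input_size`, `layer_output`). The wrapper `concat_layer` and the variable names are illustrative, not part of NNI; the chosen function receives the `[fixed_inputs, chosen optional_inputs]` list as its first argument. See the complete document for the full example.

```python
"""@nni.mutable_layers(
    {
        layer_choice: [concat_layer()],       # candidate operations (keyword args only)
        fixed_inputs: [prev_out],             # always passed to the chosen layer
        optional_inputs: [out1, out2, out3],  # candidate connections
        optional_input_size: 2,               # pick two of the three candidates
        layer_output: concat_out              # name bound to the layer's output
    }
)"""
```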
...@@ -92,9 +96,9 @@ NNI's annotation compiler transforms the annotated trial code to the code that c
The above figure shows how the trial code runs on NNI. `nnictl` processes the user's trial code to generate a search space file and compiled trial code. The former is fed to the tuner, and the latter is used to run trials.
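For reference, the configuration that the tuner produces for a single trial (and that the compiled trial code, as well as `nnictl trial codegen`, consumes) has roughly the following shape. This is a sketch inferred from the code generator added in this commit; the module name and line number in the key are placeholders.

```python
# Shape of the chosen parameters for one trial, as read by the code generator
# in this commit ("mnist" and 28 stand for <module>/mutable_block_<lineno>).
chosen_parameters = {
    "mnist/mutable_block_28": {
        "mutable_layer_0": {
            "chosen_layer": "conv2d(size=5)",   # source text of the chosen candidate call
            "chosen_inputs": ["out1", "out3"],  # names of the chosen optional inputs
        },
    },
}
```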
[Simple example of NAS on NNI](https://github.com/microsoft/nni/tree/v0.8/examples/trials/mnist-nas).
### [__TODO__] Weight sharing
Sharing weights among chosen architectures (i.e., trials) can speed up model search. For example, properly inheriting the weights of completed trials can speed up the convergence of new trials. One-Shot NAS (e.g., ENAS, DARTS) is more aggressive: the training of different architectures (i.e., subgraphs) shares the same copy of the weights in the full graph.
...@@ -102,9 +106,9 @@ Sharing weights among chosen architectures (i.e., trials) could speedup model se
We believe weight sharing (transferring) plays a key role in speeding up NAS, while finding efficient ways to share weights is still a hot research topic. We provide a key-value store for users to store and load weights; tuners and trials use a provided KV client library to access the storage.
Example of weight sharing on NNI.
### [__TODO__] Support of One-Shot NAS
One-Shot NAS is a popular approach to finding a good neural architecture within a limited time and resource budget. Basically, it builds a full graph based on the search space and uses gradient descent to eventually find the best subgraph. There are different training approaches, such as [training subgraphs (per mini-batch)][1], [training the full graph through dropout][6], and [training with architecture weights (regularization)][3]. Here we focus on the first approach, i.e., training subgraphs (ENAS).
...@@ -114,18 +118,18 @@ With the same annotated trial code, users could choose One-Shot NAS as execution
The design of One-Shot NAS on NNI is shown in the above figure. One-Shot NAS usually has only one trial job with the full graph. NNI supports running multiple such trial jobs, each of which runs independently. As One-Shot NAS is not stable, running multiple instances helps find better models. Moreover, trial jobs can also synchronize weights during running (i.e., there is only one copy of the weights, like asynchronous parameter-server mode). This may speed up convergence.
Example of One-Shot NAS on NNI.
## [__TODO__] General tuning algorithms for NAS
Like hyperparameter tuning, a relatively general algorithm for NAS is required. The general programming interface makes this task easier to some extent. We have an RL-based tuner algorithm for NAS from our contributors, and we expect efforts from the community to design and implement better NAS algorithms.
More tuning algorithms for NAS.
## [__TODO__] Export best neural architecture and code
After the NNI experiment is done, users could run `nnictl experiment export --code` to export the trial code with the best neural architecture.
## Conclusion and Future work
......
...@@ -3,4 +3,5 @@ Advanced Features
.. toctree::
MultiPhase<MultiPhase>
AdvancedNas<AdvancedNas>
NAS Programming Interface<GeneralNasInterfaces>
\ No newline at end of file
...@@ -9,7 +9,6 @@ import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
import operators as op
FLAGS = None
...@@ -215,7 +214,7 @@ def main(params):
mnist_network.labels: mnist.test.labels,
mnist_network.keep_prob: 1.0})
"""@nni.report_intermediate_result(test_acc)"""
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
...@@ -224,7 +223,7 @@ def main(params):
mnist_network.labels: mnist.test.labels,
mnist_network.keep_prob: 1.0})
"""@nni.report_final_result(test_acc)"""
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
......
...@@ -22,9 +22,11 @@
import os
import sys
import shutil
import json
from . import code_generator
from . import search_space_generator
from . import specific_code_generator
__all__ = ['generate_search_space', 'expand_annotations']
...@@ -74,7 +76,7 @@ def _generate_file_search_space(path, module):
return search_space
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
"""Expand annotations in user code. """Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not. Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str) src_dir: directory path of user code (str)
...@@ -93,11 +95,23 @@ def expand_annotations(src_dir, dst_dir):
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
os.makedirs(dst_subdir, exist_ok=True)
# generate module name from path
if src_subdir == src_dir:
package = ''
else:
assert src_subdir.startswith(src_dir + slash), src_subdir
prefix_len = len(src_dir) + 1
package = src_subdir[prefix_len:].replace(slash, '.') + '.'
for file_name in files:
src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'):
if trial_id == '':
annotated |= _expand_file_annotations(src_path, dst_path)
else:
module = package + file_name[:-3]
annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
else:
shutil.copyfile(src_path, dst_path)
...@@ -121,3 +135,22 @@ def _expand_file_annotations(src_path, dst_path):
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
with open(os.path.expanduser('~/nni/experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))) as fd:
para_cfg = json.load(fd)
annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import ast
import astor
from nni_cmd.common_utils import print_warning
# pylint: disable=unidiomatic-typecheck
para_cfg = None
prefix_name = None
def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation mutable_layers contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
call = module.body[0].value
nodes = []
mutable_id = prefix_name + '/mutable_block_' + str(lineno)
mutable_layer_cnt = 0
for arg in call.args:
fields = {'layer_choice': False,
'fixed_inputs': False,
'optional_inputs': False,
'optional_input_size': False,
'layer_output': False}
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
func_call = None
for k, value in zip(arg.keys, arg.values):
if k.id == 'layer_choice':
assert not fields['layer_choice'], 'Duplicated field: layer_choice'
assert type(value) is ast.List, 'Value of layer_choice should be a list'
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
if call_name == para_cfg[mutable_id][mutable_layer_id]['chosen_layer']:
func_call = call
assert not call.args, 'Number of args without keyword should be zero'
break
fields['layer_choice'] = True
elif k.id == 'fixed_inputs':
assert not fields['fixed_inputs'], 'Duplicated field: fixed_inputs'
assert type(value) is ast.List, 'Value of fixed_inputs should be a list'
fixed_inputs = value
fields['fixed_inputs'] = True
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [astor.to_source(var).strip() for var in value.elts]
chosen_inputs = para_cfg[mutable_id][mutable_layer_id]['chosen_inputs']
elts = []
for i in chosen_inputs:
index = var_names.index(i)
elts.append(value.elts[index])
optional_inputs = ast.List(elts=elts)
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
pass
elif k.id == 'layer_output':
assert not fields['layer_output'], 'Duplicated field: layer_output'
assert type(value) is ast.Name, 'Value of layer_output should be ast.Name type'
layer_output = value
fields['layer_output'] = True
else:
raise AssertionError('Unexpected field in mutable layer')
# make call for this mutable layer
assert fields['layer_choice'], 'layer_choice must exist'
assert fields['layer_output'], 'layer_output must exist'
if not fields['fixed_inputs']:
fixed_inputs = ast.List(elts=[])
if not fields['optional_inputs']:
optional_inputs = ast.List(elts=[])
inputs = ast.List(elts=[fixed_inputs, optional_inputs])
func_call.args.append(inputs)
node = ast.Assign(targets=[layer_output], value=func_call)
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
return module.body[0]
def parse_annotation_function(code, func_name):
"""Parse an annotation function.
Return the value of `name` keyword argument and the AST Call node.
func_name: expected function name
"""
expr = parse_annotation(code)
call = expr.value
assert type(call) is ast.Call, 'Annotation is not a function call'
assert type(call.func) is ast.Attribute, 'Unexpected annotation function'
assert type(call.func.value) is ast.Name, 'Invalid annotation function name'
assert call.func.value.id == 'nni', 'Annotation is not a NNI function'
assert call.func.attr == func_name, 'internal error #2'
assert len(call.keywords) == 1, 'Annotation function contains more than one keyword argument'
assert call.keywords[0].arg == 'name', 'Annotation keyword argument is not "name"'
name = call.keywords[0].value
return name, call
def parse_nni_variable(code):
"""Parse `nni.variable` expression.
Return the name argument and AST node of annotated expression.
code: annotation string
"""
name, call = parse_annotation_function(code, 'variable')
assert len(call.args) == 1, 'nni.variable contains more than one arguments'
arg = call.args[0]
assert type(arg) is ast.Call, 'Value of nni.variable is not a function call'
assert type(arg.func) is ast.Attribute, 'nni.variable value is not a NNI function'
assert type(arg.func.value) is ast.Name, 'nni.variable value is not a NNI function'
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'
name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
return name, arg
def parse_nni_function(code):
"""Parse `nni.function_choice` expression.
Return the AST node of annotated expression and a list of dumped function call expressions.
code: annotation string
"""
name, call = parse_annotation_function(code, 'function_choice')
funcs = [ast.dump(func, False) for func in call.args]
convert_args_to_dict(call, with_lambda=True)
name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
return call, funcs
def convert_args_to_dict(call, with_lambda=False):
"""Convert all args to a dict such that every key and value in the dict is the same as the value of the arg.
Return the AST Call node with only one arg that is the dictionary
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
del call.args[:]
call.args.append(ast.Dict(keys=keys, values=values))
return call
def make_lambda(call):
"""Wrap an AST Call node to lambda expression node.
call: ast.Call node
"""
empty_args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[])
return ast.Lambda(args=empty_args, body=call)
def test_variable_equal(node1, node2):
"""Test whether two variables are the same."""
if type(node1) is not type(node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
return True
if isinstance(node1, list):
if len(node1) != len(node2):
return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
return node1 == node2
def replace_variable_node(node, annotation):
"""Replace a node annotated by `nni.variable`.
node: the AST node to replace
annotation: annotation string
"""
assert type(node) is ast.Assign, 'nni.variable is not annotating assignment expression'
assert len(node.targets) == 1, 'Annotated assignment has more than one left-hand value'
name, expr = parse_nni_variable(annotation)
assert test_variable_equal(node.targets[0], name), 'Annotated variable has wrong name'
node.value = expr
return node
def replace_function_node(node, annotation):
"""Replace a node annotated by `nni.function_choice`.
node: the AST node to replace
annotation: annotation string
"""
target, funcs = parse_nni_function(annotation)
FuncReplacer(funcs, target).visit(node)
return node
class FuncReplacer(ast.NodeTransformer):
"""To replace target function call expressions in a node annotated by `nni.function_choice`"""
def __init__(self, funcs, target):
"""Constructor.
funcs: list of dumped function call expressions to replace
target: use this AST node to replace matching expressions
"""
self.funcs = set(funcs)
self.target = target
def visit_Call(self, node): # pylint: disable=invalid-name
if ast.dump(node, False) in self.funcs:
return self.target
return node
class Transformer(ast.NodeTransformer):
"""Transform original code to annotated code"""
def __init__(self):
self.stack = []
self.last_line = 0
self.annotated = False
def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
self.last_line = node.lineno
# do nothing for root
if not self.stack:
return self._visit_children(node)
annotation = self.stack[-1]
# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
if annotation is not None: # this expression is annotated
self.stack[-1] = None # so next expression is not
if annotation.startswith('nni.variable'):
return replace_variable_node(node, annotation)
if annotation.startswith('nni.function_choice'):
return replace_function_node(node, annotation)
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
self.annotated = True
else:
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'):
self.stack[-1] = string[1:] # mark that the next expression is annotated
return None
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
annotation = self.stack.pop()
assert annotation is None, 'Annotation has no target'
return node
def parse(code, para, module):
"""Annotate user code.
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
"""
global para_cfg
global prefix_name
para_cfg = para
prefix_name = module
try:
ast_tree = ast.parse(code)
except Exception:
raise RuntimeError('Bad Python code')
transformer = Transformer()
try:
transformer.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (transformer.last_line, exc.args[0]))
if not transformer.annotated:
return None
return astor.to_source(ast_tree)
...@@ -103,6 +103,10 @@ def parse_args():
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill)
parser_trial_codegen = parser_trial_subparsers.add_parser('codegen', help='generate trial code for a specific trial')
parser_trial_codegen.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_codegen.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to do code generation')
parser_trial_codegen.set_defaults(func=trial_codegen)
#parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
......
...@@ -25,6 +25,7 @@ import json
import datetime
import time
from subprocess import call, check_output
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
from .config_utils import Config, Experiments
...@@ -272,6 +273,17 @@ def trial_kill(args):
else:
print_error('Restful server is not running...')
def trial_codegen(args):
'''Generate code for a specific trial'''
print_warning('Currently, this command is only for nni nas programming interface.')
exp_id = check_experiment_id(args)
nni_config = Config(get_config_filename(args))
if not nni_config.get_config('experimentConfig')['useAnnotation']:
print_error('The experiment is not using annotation')
exit(1)
code_dir = nni_config.get_config('experimentConfig')['trial']['codeDir']
expand_annotations(code_dir, './exp_%s_trial_%s_code'%(exp_id, args.trial_id), exp_id, args.trial_id)
def list_experiment(args):
'''Get experiment information'''
nni_config = Config(get_config_filename(args))
......