Commit 0619de20 authored by liuzhe-lz's avatar liuzhe-lz Committed by QuanluZhang
Browse files

Improve annotation (#138)

* Improve annotation

* Minor bugfix
parent 70be7d0f
This diff is collapsed.
......@@ -124,6 +124,6 @@ else:
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '#{:d}'.format(lineno)
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
return trial.get_parameter(key)
......@@ -33,9 +33,9 @@ class SmartParamTestCase(TestCase):
def setUp(self):
params = {
'test_smartparam/choice1/choice': 2,
'test_smartparam/#{:d}/uniform'.format(lineno1): '5',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 1,
'test_smartparam/#{:d}/function_choice'.format(lineno2): 0
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 0
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
......
This diff is collapsed.
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -69,6 +69,7 @@ def _generate_file_search_space(path, module):
def expand_annotations(src_dir, dst_dir):
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str)
"""
......@@ -77,6 +78,8 @@ def expand_annotations(src_dir, dst_dir):
if dst_dir[-1] == '/':
dst_dir = dst_dir[:-1]
annotated = False
for src_subdir, dirs, files in os.walk(src_dir):
assert src_subdir.startswith(src_dir)
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
......@@ -86,17 +89,25 @@ def expand_annotations(src_dir, dst_dir):
src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'):
_expand_file_annotations(src_path, dst_path)
annotated |= _expand_file_annotations(src_path, dst_path)
else:
shutil.copyfile(src_path, dst_path)
for dir_name in dirs:
os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)
return dst_dir if annotated else src_dir
def _expand_file_annotations(src_path, dst_path):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
dst.write(code_generator.parse(src.read()))
annotated_code = code_generator.parse(src.read())
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(exc.args))
......
......@@ -161,6 +161,7 @@ class Transformer(ast.NodeTransformer):
def __init__(self):
self.stack = []
self.last_line = 0
self.annotated = False
def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
......@@ -190,8 +191,9 @@ class Transformer(ast.NodeTransformer):
def _visit_string(self, node):
string = node.value.s
if not string.startswith('@nni.'):
if string.startswith('@nni.'):
self.annotated = True
else:
return node # not an annotation, ignore it
if string.startswith('@nni.report_intermediate_result(') \
......@@ -216,7 +218,7 @@ class Transformer(ast.NodeTransformer):
def parse(code):
"""Annotate user code.
Return annotated code (str).
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
"""
try:
......@@ -224,11 +226,15 @@ def parse(code):
except Exception:
raise RuntimeError('Bad Python code')
transformer = Transformer()
try:
Transformer().visit(ast_tree)
transformer.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0]))
if not transformer.annotated:
return None
last_future_import = -1
import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)])
nodes = ast_tree.body
......
......@@ -76,7 +76,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
else:
# generate the missing name automatically
assert len(node.args) > 0, 'Smart parameter expression has no argument'
name = '#' + str(node.args[-1].lineno)
name = '__line' + str(node.args[-1].lineno)
specified_name = False
if func in ('choice', 'function_choice'):
......
......@@ -27,6 +27,7 @@ import ast
import json
import os
import shutil
import tempfile
from unittest import TestCase, main
......@@ -43,12 +44,18 @@ class AnnotationTestCase(TestCase):
self.assertEqual(search_space, json.load(f))
def test_code_generator(self):
expand_annotations('testcase/usercode', '_generated')
code_dir = expand_annotations('testcase/usercode', '_generated')
self.assertEqual(code_dir, '_generated')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
assert src.read() == dst.read()
def test_annotation_detecting(self):
dir_ = 'testcase/usercode/non_annotation'
code_dir = expand_annotations(dir_, tempfile.mkdtemp())
self.assertEqual(code_dir, dir_)
def _assert_source_equal(self, src1, src2):
with open(src1) as f1, open(src2) as f2:
ast1 = ast.dump(ast.parse(f1.read()))
......
import nni
def bar():
"""I'm doc string"""
return nni.report_final_result(0)
......@@ -3,15 +3,15 @@
"_type": "choice",
"_value": [ 0, 1, 2, 3 ]
},
"handwrite/#5/function_choice": {
"handwrite/__line5/function_choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
},
"handwrite/#8/qlognormal": {
"handwrite/__line8/qlognormal": {
"_type": "qlognormal",
"_value": [ 1.2, 3, 4.5 ]
},
"handwrite/#13/choice": {
"handwrite/__line13/choice": {
"_type": "choice",
"_value": [ 0, 1 ]
},
......
import nni
def bar():
"""I'm doc string"""
return nni.report_final_result(0)
......@@ -26,7 +26,6 @@ import string
from subprocess import Popen, PIPE, call
import tempfile
from nni_annotation import *
import random
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url
......@@ -189,13 +188,11 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation', ''.join(random.sample(string.ascii_letters + string.digits, 8)))
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path)
expand_annotations(experiment_config['trial']['codeDir'], path)
experiment_config['trial']['codeDir'] = path
search_space = generate_search_space(experiment_config['trial']['codeDir'])
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation')
path = tempfile.mkdtemp(dir=path)
code_dir = expand_annotations(experiment_config['trial']['codeDir'], path)
experiment_config['trial']['codeDir'] = code_dir
search_space = generate_search_space(code_dir)
experiment_config['searchSpace'] = json.dumps(search_space)
assert search_space, ERROR_INFO % 'Generated search space is empty'
elif experiment_config.get('searchSpacePath'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment