Commit 0619de20 authored by liuzhe-lz's avatar liuzhe-lz Committed by QuanluZhang
Browse files

Improve annotation (#138)

* Improve annotation

* Minor bugfix
parent 70be7d0f
This diff is collapsed.
...@@ -124,6 +124,6 @@ else: ...@@ -124,6 +124,6 @@ else:
del frame # see official doc del frame # see official doc
module = inspect.getmodulename(filename) module = inspect.getmodulename(filename)
if name is None: if name is None:
name = '#{:d}'.format(lineno) name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func) key = '{}/{}/{}'.format(module, name, func)
return trial.get_parameter(key) return trial.get_parameter(key)
...@@ -33,9 +33,9 @@ class SmartParamTestCase(TestCase): ...@@ -33,9 +33,9 @@ class SmartParamTestCase(TestCase):
def setUp(self): def setUp(self):
params = { params = {
'test_smartparam/choice1/choice': 2, 'test_smartparam/choice1/choice': 2,
'test_smartparam/#{:d}/uniform'.format(lineno1): '5', 'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 1, 'test_smartparam/func/function_choice': 1,
'test_smartparam/#{:d}/function_choice'.format(lineno2): 0 'test_smartparam/__line{:d}/function_choice'.format(lineno2): 0
} }
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params } nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
......
This diff is collapsed.
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -69,6 +69,7 @@ def _generate_file_search_space(path, module): ...@@ -69,6 +69,7 @@ def _generate_file_search_space(path, module):
def expand_annotations(src_dir, dst_dir): def expand_annotations(src_dir, dst_dir):
"""Expand annotations in user code. """Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str) src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str) dst_dir: directory to place generated files (str)
""" """
...@@ -77,6 +78,8 @@ def expand_annotations(src_dir, dst_dir): ...@@ -77,6 +78,8 @@ def expand_annotations(src_dir, dst_dir):
if dst_dir[-1] == '/': if dst_dir[-1] == '/':
dst_dir = dst_dir[:-1] dst_dir = dst_dir[:-1]
annotated = False
for src_subdir, dirs, files in os.walk(src_dir): for src_subdir, dirs, files in os.walk(src_dir):
assert src_subdir.startswith(src_dir) assert src_subdir.startswith(src_dir)
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1) dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
...@@ -86,17 +89,25 @@ def expand_annotations(src_dir, dst_dir): ...@@ -86,17 +89,25 @@ def expand_annotations(src_dir, dst_dir):
src_path = os.path.join(src_subdir, file_name) src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name) dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'): if file_name.endswith('.py'):
_expand_file_annotations(src_path, dst_path) annotated |= _expand_file_annotations(src_path, dst_path)
else: else:
shutil.copyfile(src_path, dst_path) shutil.copyfile(src_path, dst_path)
for dir_name in dirs: for dir_name in dirs:
os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True) os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)
return dst_dir if annotated else src_dir
def _expand_file_annotations(src_path, dst_path): def _expand_file_annotations(src_path, dst_path):
with open(src_path) as src, open(dst_path, 'w') as dst: with open(src_path) as src, open(dst_path, 'w') as dst:
try: try:
dst.write(code_generator.parse(src.read())) annotated_code = code_generator.parse(src.read())
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
if exc.args: if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(exc.args)) raise RuntimeError(src_path + ' ' + '\n'.join(exc.args))
......
...@@ -161,6 +161,7 @@ class Transformer(ast.NodeTransformer): ...@@ -161,6 +161,7 @@ class Transformer(ast.NodeTransformer):
def __init__(self): def __init__(self):
self.stack = [] self.stack = []
self.last_line = 0 self.last_line = 0
self.annotated = False
def visit(self, node): def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)): if isinstance(node, (ast.expr, ast.stmt)):
...@@ -190,8 +191,9 @@ class Transformer(ast.NodeTransformer): ...@@ -190,8 +191,9 @@ class Transformer(ast.NodeTransformer):
def _visit_string(self, node): def _visit_string(self, node):
string = node.value.s string = node.value.s
if string.startswith('@nni.'):
if not string.startswith('@nni.'): self.annotated = True
else:
return node # not an annotation, ignore it return node # not an annotation, ignore it
if string.startswith('@nni.report_intermediate_result(') \ if string.startswith('@nni.report_intermediate_result(') \
...@@ -216,7 +218,7 @@ class Transformer(ast.NodeTransformer): ...@@ -216,7 +218,7 @@ class Transformer(ast.NodeTransformer):
def parse(code): def parse(code):
"""Annotate user code. """Annotate user code.
Return annotated code (str). Return annotated code (str) if annotation detected; return None if not.
code: original user code (str) code: original user code (str)
""" """
try: try:
...@@ -224,11 +226,15 @@ def parse(code): ...@@ -224,11 +226,15 @@ def parse(code):
except Exception: except Exception:
raise RuntimeError('Bad Python code') raise RuntimeError('Bad Python code')
transformer = Transformer()
try: try:
Transformer().visit(ast_tree) transformer.visit(ast_tree)
except AssertionError as exc: except AssertionError as exc:
raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0])) raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0]))
if not transformer.annotated:
return None
last_future_import = -1 last_future_import = -1
import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)]) import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)])
nodes = ast_tree.body nodes = ast_tree.body
......
...@@ -76,7 +76,7 @@ class SearchSpaceGenerator(ast.NodeVisitor): ...@@ -76,7 +76,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
else: else:
# generate the missing name automatically # generate the missing name automatically
assert len(node.args) > 0, 'Smart parameter expression has no argument' assert len(node.args) > 0, 'Smart parameter expression has no argument'
name = '#' + str(node.args[-1].lineno) name = '__line' + str(node.args[-1].lineno)
specified_name = False specified_name = False
if func in ('choice', 'function_choice'): if func in ('choice', 'function_choice'):
......
...@@ -27,6 +27,7 @@ import ast ...@@ -27,6 +27,7 @@ import ast
import json import json
import os import os
import shutil import shutil
import tempfile
from unittest import TestCase, main from unittest import TestCase, main
...@@ -43,12 +44,18 @@ class AnnotationTestCase(TestCase): ...@@ -43,12 +44,18 @@ class AnnotationTestCase(TestCase):
self.assertEqual(search_space, json.load(f)) self.assertEqual(search_space, json.load(f))
def test_code_generator(self): def test_code_generator(self):
expand_annotations('testcase/usercode', '_generated') code_dir = expand_annotations('testcase/usercode', '_generated')
self.assertEqual(code_dir, '_generated')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py') self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py') self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst: with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
assert src.read() == dst.read() assert src.read() == dst.read()
def test_annotation_detecting(self):
dir_ = 'testcase/usercode/non_annotation'
code_dir = expand_annotations(dir_, tempfile.mkdtemp())
self.assertEqual(code_dir, dir_)
def _assert_source_equal(self, src1, src2): def _assert_source_equal(self, src1, src2):
with open(src1) as f1, open(src2) as f2: with open(src1) as f1, open(src2) as f2:
ast1 = ast.dump(ast.parse(f1.read())) ast1 = ast.dump(ast.parse(f1.read()))
......
import nni
def bar():
"""I'm doc string"""
return nni.report_final_result(0)
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
"_type": "choice", "_type": "choice",
"_value": [ 0, 1, 2, 3 ] "_value": [ 0, 1, 2, 3 ]
}, },
"handwrite/#5/function_choice": { "handwrite/__line5/function_choice": {
"_type": "choice", "_type": "choice",
"_value": [ 0, 1, 2 ] "_value": [ 0, 1, 2 ]
}, },
"handwrite/#8/qlognormal": { "handwrite/__line8/qlognormal": {
"_type": "qlognormal", "_type": "qlognormal",
"_value": [ 1.2, 3, 4.5 ] "_value": [ 1.2, 3, 4.5 ]
}, },
"handwrite/#13/choice": { "handwrite/__line13/choice": {
"_type": "choice", "_type": "choice",
"_value": [ 0, 1 ] "_value": [ 0, 1 ]
}, },
......
import nni
def bar():
"""I'm doc string"""
return nni.report_final_result(0)
...@@ -26,7 +26,6 @@ import string ...@@ -26,7 +26,6 @@ import string
from subprocess import Popen, PIPE, call from subprocess import Popen, PIPE, call
import tempfile import tempfile
from nni_annotation import * from nni_annotation import *
import random
from .launcher_utils import validate_all_content from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url from .url_utils import cluster_metadata_url, experiment_url
...@@ -189,13 +188,11 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No ...@@ -189,13 +188,11 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation', ''.join(random.sample(string.ascii_letters + string.digits, 8))) path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation')
if os.path.isdir(path): path = tempfile.mkdtemp(dir=path)
shutil.rmtree(path) code_dir = expand_annotations(experiment_config['trial']['codeDir'], path)
os.makedirs(path) experiment_config['trial']['codeDir'] = code_dir
expand_annotations(experiment_config['trial']['codeDir'], path) search_space = generate_search_space(code_dir)
experiment_config['trial']['codeDir'] = path
search_space = generate_search_space(experiment_config['trial']['codeDir'])
experiment_config['searchSpace'] = json.dumps(search_space) experiment_config['searchSpace'] = json.dumps(search_space)
assert search_space, ERROR_INFO % 'Generated search space is empty' assert search_space, ERROR_INFO % 'Generated search space is empty'
elif experiment_config.get('searchSpacePath'): elif experiment_config.get('searchSpacePath'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment