Unverified Commit 12410686 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #20 from microsoft/master

pull code
parents 611a45fc 61fec446
......@@ -6,7 +6,7 @@ $CWD = $PWD
echo ""
echo "===========================Testing: nni_annotation==========================="
cd $CWD/../tools/
python -m unittest -v nni_annotation/test_annotation.py
python -m unittest -v nni_annotation/test_annotation.py
## Export certain environment variables for unittest code to work
$env:NNI_TRIAL_JOB_ID="test_trial_job_id"
......
## NNI CTL
The NNI CTL module is used to control Neural Network Intelligence, including starting a new experiment, stopping an experiment, and updating an experiment, etc.
## Environment
```
......@@ -9,7 +9,7 @@ python >= 3.5
## Installation
1. Enter tools directory
1. Enter tools directory
1. Use pip to install packages
* Install for current user:
......@@ -24,17 +24,17 @@ python >= 3.5
python3 -m pip install -e .
```
1. Change the mode of nnictl file
1. Change the mode of nnictl file
```bash
chmod +x ./nnictl
```
1. Add nnictl to your PATH system environment variable.
* You could use `export` command to set PATH variable temporary.
export PATH={your nnictl path}:$PATH
export PATH={your nnictl path}:$PATH
* Or you could edit your `/etc/profile` file.
......
# NNI Annotation
## Overview
To improve user experience and reduce user effort, we design an annotation grammar. Using NNI annotation, users can adapt their code to NNI just by adding some standalone annotating strings, which does not affect the execution of the original code.
To improve user experience and reduce user effort, we design an annotation grammar. Using NNI annotation, users can adapt their code to NNI just by adding some standalone annotating strings, which does not affect the execution of the original code.
Below is an example:
......@@ -28,7 +28,7 @@ In NNI, there are mainly four types of annotation:
**Arguments**
- **sampling_algo**: Sampling algorithm that specifies a search space. User should replace it with a built-in NNI sampling function whose name consists of an `nni.` identification and a search space type specified in [SearchSpaceSpec](https://nni.readthedocs.io/en/latest/SearchSpaceSpec.html) such as `choice` or `uniform`.
- **sampling_algo**: Sampling algorithm that specifies a search space. User should replace it with a built-in NNI sampling function whose name consists of an `nni.` identification and a search space type specified in [SearchSpaceSpec](https://nni.readthedocs.io/en/latest/SearchSpaceSpec.html) such as `choice` or `uniform`.
- **name**: The name of the variable that the selected value will be assigned to. Note that this argument should be the same as the left value of the following assignment statement.
There are 10 types to express your search space as follows:
......
......@@ -22,16 +22,18 @@
import os
import sys
import shutil
import json
from . import code_generator
from . import search_space_generator
from . import specific_code_generator
__all__ = ['generate_search_space', 'expand_annotations']
slash = '/'
if sys.platform == "win32":
slash = '\\'
slash = '\\'
def generate_search_space(code_dir):
"""Generate search space from Python source code.
......@@ -39,7 +41,7 @@ def generate_search_space(code_dir):
code_dir: directory path of source files (str)
"""
search_space = {}
if code_dir.endswith(slash):
code_dir = code_dir[:-1]
......@@ -74,7 +76,7 @@ def _generate_file_search_space(path, module):
return search_space
def expand_annotations(src_dir, dst_dir):
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
......@@ -82,7 +84,7 @@ def expand_annotations(src_dir, dst_dir):
"""
if src_dir[-1] == slash:
src_dir = src_dir[:-1]
if dst_dir[-1] == slash:
dst_dir = dst_dir[:-1]
......@@ -93,11 +95,23 @@ def expand_annotations(src_dir, dst_dir):
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
os.makedirs(dst_subdir, exist_ok=True)
# generate module name from path
if src_subdir == src_dir:
package = ''
else:
assert src_subdir.startswith(src_dir + slash), src_subdir
prefix_len = len(src_dir) + 1
package = src_subdir[prefix_len:].replace(slash, '.') + '.'
for file_name in files:
src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'):
annotated |= _expand_file_annotations(src_path, dst_path)
if trial_id == '':
annotated |= _expand_file_annotations(src_path, dst_path)
else:
module = package + file_name[:-3]
annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
else:
shutil.copyfile(src_path, dst_path)
......@@ -121,3 +135,22 @@ def _expand_file_annotations(src_path, dst_path):
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
with open(os.path.expanduser('~/nni/experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))) as fd:
para_cfg = json.load(fd)
annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
......@@ -60,7 +60,7 @@ def parse_annotation_mutable_layers(code, lineno):
kw_args = []
kw_values = []
for kw in call.keywords:
kw_args.append(kw.arg)
kw_args.append(ast.Str(s=kw.arg))
kw_values.append(kw.value)
call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values))
call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values)
......@@ -79,7 +79,7 @@ def parse_annotation_mutable_layers(code, lineno):
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num, 'Value of optional_input_size should be a number'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
optional_input_size = value
fields['optional_input_size'] = True
elif k.id == 'layer_output':
......@@ -102,13 +102,14 @@ def parse_annotation_mutable_layers(code, lineno):
if fields['fixed_inputs']:
target_call_args.append(fixed_inputs)
else:
target_call_args.append(ast.NameConstant(value=None))
target_call_args.append(ast.List(elts=[]))
if fields['optional_inputs']:
target_call_args.append(optional_inputs)
assert fields['optional_input_size'], 'optional_input_size must exist when optional_inputs exists'
target_call_args.append(optional_input_size)
else:
target_call_args.append(ast.NameConstant(value=None))
target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
node = ast.Assign(targets=[layer_output], value=target_call)
nodes.append(node)
......@@ -229,7 +230,7 @@ def test_variable_equal(node1, node2):
if len(node1) != len(node2):
return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
return node1 == node2
......@@ -314,20 +315,20 @@ class Transformer(ast.NodeTransformer):
else:
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter('):
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning(deprecated_message)
if string.startswith('@nni.report_intermediate_result(') \
or string.startswith('@nni.report_final_result(') \
or string.startswith('@nni.get_next_parameter('):
if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \
or string.startswith('@nni.get_next_parameter'):
return parse_annotation(string[1:]) # expand annotation string to code
if string.startswith('@nni.mutable_layers('):
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
if string.startswith('@nni.variable(') \
or string.startswith('@nni.function_choice('):
if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'):
self.stack[-1] = string[1:] # mark that the next expression is annotated
return None
......
......@@ -54,12 +54,14 @@ class SearchSpaceGenerator(ast.NodeTransformer):
def generate_mutable_layer_search_space(self, args):
mutable_block = args[0].s
mutable_layer = args[1].s
if mutable_block not in self.search_space:
self.search_space[mutable_block] = dict()
self.search_space[mutable_block][mutable_layer] = {
'layer_choice': [key.s for key in args[2].keys],
'optional_inputs': [key.s for key in args[5].keys],
'optional_input_size': args[6].n
key = self.module_name + '/' + mutable_block
args[0].s = key
if key not in self.search_space:
self.search_space[key] = dict()
self.search_space[key][mutable_layer] = {
'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
}
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import ast
import astor
from nni_cmd.common_utils import print_warning
# pylint: disable=unidiomatic-typecheck
para_cfg = None
prefix_name = None
def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation mutable_layers contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
call = module.body[0].value
nodes = []
mutable_id = prefix_name + '/mutable_block_' + str(lineno)
mutable_layer_cnt = 0
for arg in call.args:
fields = {'layer_choice': False,
'fixed_inputs': False,
'optional_inputs': False,
'optional_input_size': False,
'layer_output': False}
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
func_call = None
for k, value in zip(arg.keys, arg.values):
if k.id == 'layer_choice':
assert not fields['layer_choice'], 'Duplicated field: layer_choice'
assert type(value) is ast.List, 'Value of layer_choice should be a list'
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
if call_name == para_cfg[mutable_id][mutable_layer_id]['chosen_layer']:
func_call = call
assert not call.args, 'Number of args without keyword should be zero'
break
fields['layer_choice'] = True
elif k.id == 'fixed_inputs':
assert not fields['fixed_inputs'], 'Duplicated field: fixed_inputs'
assert type(value) is ast.List, 'Value of fixed_inputs should be a list'
fixed_inputs = value
fields['fixed_inputs'] = True
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [astor.to_source(var).strip() for var in value.elts]
chosen_inputs = para_cfg[mutable_id][mutable_layer_id]['chosen_inputs']
elts = []
for i in chosen_inputs:
index = var_names.index(i)
elts.append(value.elts[index])
optional_inputs = ast.List(elts=elts)
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
pass
elif k.id == 'layer_output':
assert not fields['layer_output'], 'Duplicated field: layer_output'
assert type(value) is ast.Name, 'Value of layer_output should be ast.Name type'
layer_output = value
fields['layer_output'] = True
else:
raise AssertionError('Unexpected field in mutable layer')
# make call for this mutable layer
assert fields['layer_choice'], 'layer_choice must exist'
assert fields['layer_output'], 'layer_output must exist'
if not fields['fixed_inputs']:
fixed_inputs = ast.List(elts=[])
if not fields['optional_inputs']:
optional_inputs = ast.List(elts=[])
inputs = ast.List(elts=[fixed_inputs, optional_inputs])
func_call.args.append(inputs)
node = ast.Assign(targets=[layer_output], value=func_call)
nodes.append(node)
return nodes
def parse_annotation(code):
    """Parse an annotation string into its single expression node.

    Return an AST Expr node.
    code: annotation string (excluding '@')
    """
    tree = ast.parse(code)
    assert type(tree) is ast.Module, 'internal error #1'
    assert len(tree.body) == 1, 'Annotation contains more than one expression'
    expr = tree.body[0]
    assert type(expr) is ast.Expr, 'Annotation is not expression'
    return expr
def parse_annotation_function(code, func_name):
    """Parse an annotation that must be a call to ``nni.<func_name>``.

    Return the value of the `name` keyword argument and the AST Call node.
    func_name: expected function name
    """
    call = parse_annotation(code).value
    assert type(call) is ast.Call, 'Annotation is not a function call'
    assert type(call.func) is ast.Attribute, 'Unexpected annotation function'
    assert type(call.func.value) is ast.Name, 'Invalid annotation function name'
    assert call.func.value.id == 'nni', 'Annotation is not a NNI function'
    assert call.func.attr == func_name, 'internal error #2'
    assert len(call.keywords) == 1, 'Annotation function contains more than one keyword argument'
    keyword = call.keywords[0]
    assert keyword.arg == 'name', 'Annotation keyword argument is not "name"'
    return keyword.value, call
def parse_nni_variable(code):
    """Parse an `nni.variable` expression.

    Return the name argument and the AST node of the annotated expression.
    code: annotation string
    """
    name, call = parse_annotation_function(code, 'variable')

    assert len(call.args) == 1, 'nni.variable contains more than one arguments'
    inner = call.args[0]
    assert type(inner) is ast.Call, 'Value of nni.variable is not a function call'
    assert type(inner.func) is ast.Attribute, 'nni.variable value is not a NNI function'
    assert type(inner.func.value) is ast.Name, 'nni.variable value is not a NNI function'
    assert inner.func.value.id == 'nni', 'nni.variable value is not a NNI function'

    # Attach the variable's name (as source text) to the sampling call.
    inner.keywords.append(
        ast.keyword(arg='name', value=ast.Str(s=astor.to_source(name).strip())))
    # `nni.choice` expects its candidates packed into a single dict argument.
    if inner.func.attr == 'choice':
        convert_args_to_dict(inner)
    return name, inner
def parse_nni_function(code):
    """Parse an `nni.function_choice` expression.

    Return the rewritten AST Call node and a list of dumped function-call
    expressions (used later to locate matching call sites).
    code: annotation string
    """
    name, call = parse_annotation_function(code, 'function_choice')
    dumped_funcs = [ast.dump(candidate, False) for candidate in call.args]
    convert_args_to_dict(call, with_lambda=True)
    # Replace the `name` keyword's value with the name's source text.
    call.keywords[0].value = ast.Str(s=astor.to_source(name).strip())
    return call, dumped_funcs
def convert_args_to_dict(call, with_lambda=False):
"""Convert all args to a dict such that every key and value in the dict is the same as the value of the arg.
Return the AST Call node with only one arg that is the dictionary
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
del call.args[:]
call.args.append(ast.Dict(keys=keys, values=values))
return call
def make_lambda(call):
    """Wrap an AST Call node in a zero-argument lambda expression node.

    call: ast.Call node
    """
    return ast.Lambda(
        args=ast.arguments(args=[], vararg=None, kwarg=None, defaults=[]),
        body=call)
def test_variable_equal(node1, node2):
    """Return True when two AST fragments are structurally identical.

    Location attributes (lineno, col_offset) and expression context (ctx)
    are ignored, so the same variable written at different source positions
    still compares equal.
    """
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        # Compare every field except location/context bookkeeping.
        ignored = ('lineno', 'col_offset', 'ctx')
        return all(
            test_variable_equal(value, getattr(node2, field))
            for field, value in vars(node1).items()
            if field not in ignored
        )
    if isinstance(node1, list):
        return len(node1) == len(node2) and all(
            test_variable_equal(a, b) for a, b in zip(node1, node2)
        )
    return node1 == node2
def replace_variable_node(node, annotation):
    """Replace an assignment annotated by `nni.variable`.

    node: the AST Assign node to rewrite in place
    annotation: annotation string
    """
    assert type(node) is ast.Assign, 'nni.variable is not annotating assignment expression'
    assert len(node.targets) == 1, 'Annotated assignment has more than one left-hand value'
    expected_name, replacement = parse_nni_variable(annotation)
    # The annotation's `name` must match the assignment's left-hand side.
    assert test_variable_equal(node.targets[0], expected_name), 'Annotated variable has wrong name'
    node.value = replacement
    return node
def replace_function_node(node, annotation):
    """Replace matching call expressions in a node annotated by `nni.function_choice`.

    node: the AST node to rewrite in place
    annotation: annotation string
    """
    replacement, dumped_funcs = parse_nni_function(annotation)
    FuncReplacer(dumped_funcs, replacement).visit(node)
    return node
class FuncReplacer(ast.NodeTransformer):
    """Replace target function-call expressions in a node annotated by `nni.function_choice`."""

    def __init__(self, funcs, target):
        """Constructor.

        funcs: list of dumped function-call expressions to replace
        target: AST node substituted for every matching call
        """
        self.funcs = set(funcs)
        self.target = target

    def visit_Call(self, node):  # pylint: disable=invalid-name
        # A call matches when its position-free dump is one of the
        # recorded candidate expressions.
        return self.target if ast.dump(node, False) in self.funcs else node
class Transformer(ast.NodeTransformer):
"""Transform original code to annotated code"""
def __init__(self):
self.stack = []
self.last_line = 0
self.annotated = False
def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
self.last_line = node.lineno
# do nothing for root
if not self.stack:
return self._visit_children(node)
annotation = self.stack[-1]
# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
if annotation is not None: # this expression is annotated
self.stack[-1] = None # so next expression is not
if annotation.startswith('nni.variable'):
return replace_variable_node(node, annotation)
if annotation.startswith('nni.function_choice'):
return replace_function_node(node, annotation)
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
self.annotated = True
else:
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'):
self.stack[-1] = string[1:] # mark that the next expression is annotated
return None
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
annotation = self.stack.pop()
assert annotation is None, 'Annotation has no target'
return node
def parse(code, para, module):
    """Annotate user code for a specific trial.

    Return annotated code (str) if annotation detected; return None if not.
    code: original user code (str)
    para: chosen-parameter configuration dict for the current trial
    module: module-name prefix used to build mutable-block identifiers
    Raises RuntimeError on unparsable code or an invalid annotation.
    """
    global para_cfg
    global prefix_name
    para_cfg = para
    prefix_name = module
    try:
        ast_tree = ast.parse(code)
    except Exception:
        raise RuntimeError('Bad Python code')

    transformer = Transformer()
    try:
        transformer.visit(ast_tree)
    except AssertionError as exc:
        # Bug fix: ast.Module has no `last_line` attribute — the transformer
        # tracks the last visited line, so report that instead of raising a
        # spurious AttributeError while formatting the error message.
        raise RuntimeError('%d: %s' % (transformer.last_line, exc.args[0]))

    if not transformer.annotated:
        return None
    return astor.to_source(ast_tree)
......@@ -161,7 +161,7 @@ def main():
def generate_default_params():
params = {'data_dir': '/tmp/tensorflow/mnist/input_data',
'dropout_rate': 0.5, 'channel_1_num': 32, 'channel_2_num': 64,
'conv_size': 5, 'pool_size': 2, 'hidden_size': 1024, 'batch_size':
'conv_size': 5, 'pool_size': 2, 'hidden_size': 1024, 'batch_size':
50, 'batch_num': 200, 'learning_rate': 0.0001}
return params
......
......@@ -44,7 +44,7 @@ class MnistNetwork(object):
self.x = tf.placeholder(tf.float32, [None, self.x_dim], name = 'input_x')
self.y = tf.placeholder(tf.float32, [None, self.y_dim], name = 'input_y')
self.keep_prob = tf.placeholder(tf.float32, name = 'keep_prob')
# Reshape to use within a convolutional neural net.
# Last dimension is for "features" - there is only one here, since images are
# grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
......@@ -55,8 +55,8 @@ class MnistNetwork(object):
#print('input dim cannot be sqrt and reshape. input dim: ' + str(self.x_dim))
logger.debug('input dim cannot be sqrt and reshape. input dim: ' + str(self.x_dim))
raise
x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1])
x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1])
# First convolutional layer - maps one grayscale image to 32 feature maps.
with tf.name_scope('conv1'):
W_conv1 = weight_variable([self.conv_size, self.conv_size, 1, self.channel_1_num])
......@@ -68,38 +68,38 @@ class MnistNetwork(object):
with tf.name_scope('pool1'):
"""@nni.function_choice(max_pool(h_conv1, self.pool_size),avg_pool(h_conv1, self.pool_size),name=max_pool)"""
h_pool1 = max_pool(h_conv1, self.pool_size)
# Second convolutional layer -- maps 32 feature maps to 64.
with tf.name_scope('conv2'):
W_conv2 = weight_variable([self.conv_size, self.conv_size, self.channel_1_num, self.channel_2_num])
b_conv2 = bias_variable([self.channel_2_num])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
# Second pooling layer.
with tf.name_scope('pool2'):
#"""@nni.dynamic(input={cnn_block:1, concat:2},function_choice={"cnn_block":(x,nni.choice([3,4])),"cnn_block":(x),"concat":(x,y)},limit={"cnn_block.input":[concat,input],"concat.input":[this.depth-1,this.depth-3,this.depth-5],"graph.width":[1]})"""
h_pool2 = max_pool(h_conv2, self.pool_size)
# Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image
# is down to 7x7x64 feature maps -- maps this to 1024 features.
last_dim = int(input_dim / (self.pool_size * self.pool_size))
with tf.name_scope('fc1'):
W_fc1 = weight_variable([last_dim * last_dim * self.channel_2_num, self.hidden_size])
b_fc1 = bias_variable([self.hidden_size])
h_pool2_flat = tf.reshape(h_pool2, [-1, last_dim * last_dim * self.channel_2_num])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of features.
with tf.name_scope('dropout'):
h_fc1_drop = tf.nn.dropout(h_fc1, self.keep_prob)
# Map the 1024 features to 10 classes, one for each digit
with tf.name_scope('fc2'):
W_fc2 = weight_variable([self.hidden_size, self.y_dim])
b_fc2 = bias_variable([self.y_dim])
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
with tf.name_scope('loss'):
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = self.y, logits = y_conv))
with tf.name_scope('adam_optimizer'):
......@@ -121,7 +121,7 @@ def max_pool(x, pool_size):
strides=[1, pool_size, pool_size, 1], padding='SAME')
def avg_pool(x,pool_size):
return tf.nn.avg_pool(x, ksize=[1, pool_size, pool_size, 1],
strides=[1, pool_size, pool_size, 1], padding='SAME')
strides=[1, pool_size, pool_size, 1], padding='SAME')
def weight_variable(shape):
"""weight_variable generates a weight variable of a given shape."""
......@@ -163,12 +163,12 @@ def main():
'''@nni.variable(nni.choice(1,5),name=dropout_rate)'''
dropout_rate=0.5
mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1], mnist_network.keep_prob: dropout_rate})
if i % 100 == 0:
#train_accuracy = mnist_network.accuracy.eval(feed_dict={
# mnist_network.x: batch[0], mnist_network.y: batch[1], mnist_network.keep_prob: params['dropout_rate']})
#print('step %d, training accuracy %g' % (i, train_accuracy))
test_acc = mnist_network.accuracy.eval(feed_dict={
mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels, mnist_network.keep_prob: 1.0})
'''@nni.report_intermediate_result(test_acc)'''
......@@ -196,7 +196,7 @@ if __name__ == '__main__':
#FLAGS, unparsed = parse_command()
#original_params = parse_init_json(FLAGS.init_file_path, {})
#pipe_interface.set_params_to_env()
'''@nni.get_next_parameter()'''
try:
......
......@@ -128,7 +128,7 @@ advisor_schema_dict = {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('min_budget'): setNumberRange('min_budget', int, 0, 9999),
Optional('max_budget'): setNumberRange('max_budget', int, 0, 9999),
Optional('eta'):setNumberRange('eta', int, 0, 9999),
Optional('eta'):setNumberRange('eta', int, 0, 9999),
Optional('min_points_in_model'): setNumberRange('min_points_in_model', int, 0, 9999),
Optional('top_n_percent'): setNumberRange('top_n_percent', int, 1, 99),
Optional('num_samples'): setNumberRange('num_samples', int, 1, 9999),
......@@ -235,7 +235,7 @@ kubeflow_trial_schema = {
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
}
}
}
}
......
......@@ -83,7 +83,7 @@ class Experiments:
self.experiments[id]['fileName'] = file_name
self.experiments[id]['platform'] = platform
self.write_file()
def update_experiment(self, id, key, value):
'''Update experiment'''
if id not in self.experiments:
......@@ -91,17 +91,17 @@ class Experiments:
self.experiments[id][key] = value
self.write_file()
return True
def remove_experiment(self, id):
'''remove an experiment by id'''
if id in self.experiments:
self.experiments.pop(id)
self.write_file()
def get_all_experiments(self):
'''return all of experiments'''
return self.experiments
def write_file(self):
'''save config to local file'''
try:
......
......@@ -39,6 +39,7 @@ import site
import time
from pathlib import Path
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment
def get_log_path(config_file_name):
'''generate stdout and stderr log path'''
......@@ -102,7 +103,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
print_error('Port %s is used by another process, please reset the port!\n' \
'You could use \'nnictl create --help\' to get help information' % port)
exit(1)
if (platform != 'local') and detect_port(int(port) + 1):
print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
'You could set another port to start experiment!\n' \
......@@ -110,7 +111,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
exit(1)
print_normal('Starting restful server...')
entry_dir = get_nni_installation_path()
entry_file = os.path.join(entry_dir, 'main.js')
......@@ -220,7 +221,7 @@ def setNNIManagerIp(experiment_config, port, config_file_name):
return True, None
def set_pai_config(experiment_config, port, config_file_name):
'''set pai configuration'''
'''set pai configuration'''
pai_config_data = dict()
pai_config_data['pai_config'] = experiment_config['paiConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(pai_config_data), REST_TIME_OUT)
......@@ -239,7 +240,7 @@ def set_pai_config(experiment_config, port, config_file_name):
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_kubeflow_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
'''set kubeflow configuration'''
kubeflow_config_data = dict()
kubeflow_config_data['kubeflow_config'] = experiment_config['kubeflowConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT)
......@@ -258,7 +259,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
'''set kubeflow configuration'''
frameworkcontroller_config_data = dict()
frameworkcontroller_config_data['frameworkcontroller_config'] = experiment_config['frameworkcontrollerConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT)
......@@ -318,7 +319,7 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'pai':
request_data['clusterMetaData'].append(
{'key': 'pai_config', 'value': experiment_config['paiConfig']})
{'key': 'pai_config', 'value': experiment_config['paiConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'kubeflow':
......@@ -346,13 +347,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
# check execution policy in powershell
if sys.platform == 'win32':
execution_policy = check_output(['powershell.exe','Get-ExecutionPolicy']).decode('ascii').strip()
if execution_policy == 'Restricted':
print_error('PowerShell execution policy error, please run PowerShell as administrator with this command first:\r\n'\
+ '\'Set-ExecutionPolicy -ExecutionPolicy Unrestricted\'')
exit(1)
# check packages for tuner
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
......@@ -430,7 +424,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
#set pai config
if experiment_config['trainingServicePlatform'] == 'pai':
print_normal('Setting pai config...')
......@@ -445,7 +439,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
#set kubeflow config
if experiment_config['trainingServicePlatform'] == 'kubeflow':
print_normal('Setting kubeflow config...')
......@@ -461,7 +455,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
#set kubeflow config
#set frameworkcontroller config
if experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
print_normal('Setting frameworkcontroller config...')
config_result, err_msg = set_frameworkcontroller_config(experiment_config, args.port, config_file_name)
......@@ -499,7 +493,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
else:
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)
#save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name, experiment_config['trainingServicePlatform'])
......@@ -508,6 +502,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
def resume_experiment(args):
'''resume an experiment'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
experiment_id = None
......
......@@ -21,7 +21,7 @@
import os
import json
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from schema import SchemaMissingKeyError, SchemaForbiddenKeyError, SchemaUnexpectedTypeError, SchemaWrongKeyError, SchemaError
from .common_utils import get_json_content, print_error, print_warning, print_normal
from schema import Schema, And, Use, Optional, Regex, Or
......@@ -62,7 +62,7 @@ def parse_path(experiment_config, config_path):
expand_path(experiment_config['assessor'], 'codeDir')
if experiment_config.get('advisor'):
expand_path(experiment_config['advisor'], 'codeDir')
#if users use relative path, convert it to absolute path
root_path = os.path.dirname(config_path)
if experiment_config.get('searchSpacePath'):
......@@ -80,8 +80,8 @@ def parse_path(experiment_config, config_path):
parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
def validate_search_space_content(experiment_config):
'''Validate searchspace content,
if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
'''Validate searchspace content,
if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
it will not be a valid searchspace file'''
try:
search_space_content = json.load(open(experiment_config.get('searchSpacePath'), 'r'))
......@@ -110,7 +110,7 @@ def validate_kubeflow_operators(experiment_config):
if experiment_config.get('trial').get('master') is None:
print_error('kubeflow with pytorch-operator must set master')
exit(1)
if experiment_config.get('kubeflowConfig').get('storage') == 'nfs':
if experiment_config.get('kubeflowConfig').get('nfs') is None:
print_error('please set nfs configuration!')
......@@ -170,7 +170,7 @@ def validate_common_content(experiment_config):
else:
print_error(error)
exit(1)
#set default value
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
......
......@@ -83,7 +83,7 @@ def parse_args():
parser_updater_duration.add_argument('--value', '-v', required=True, help='the unit of time should in {\'s\', \'m\', \'h\', \'d\'}')
parser_updater_duration.set_defaults(func=update_duration)
parser_updater_trialnum = parser_updater_subparsers.add_parser('trialnum', help='update maxtrialnum')
parser_updater_trialnum.add_argument('--id', '-i', dest='id', help='the id of experiment')
parser_updater_trialnum.add_argument('id', nargs='?', help='the id of experiment')
parser_updater_trialnum.add_argument('--value', '-v', required=True)
parser_updater_trialnum.set_defaults(func=update_trialnum)
......@@ -103,6 +103,10 @@ def parse_args():
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill)
parser_trial_codegen = parser_trial_subparsers.add_parser('codegen', help='generate trial code for a specific trial')
parser_trial_codegen.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_codegen.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to do code generation')
parser_trial_codegen.set_defaults(func=trial_codegen)
#parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
......@@ -172,7 +176,7 @@ def parse_args():
parser_package_subparsers = parser_package.add_subparsers()
parser_package_install = parser_package_subparsers.add_parser('install', help='install packages')
parser_package_install.add_argument('--name', '-n', dest='name', help='package name to be installed')
parser_package_install.set_defaults(func=package_install)
parser_package_install.set_defaults(func=package_install)
parser_package_show = parser_package_subparsers.add_parser('show', help='show the information of packages')
parser_package_show.set_defaults(func=package_show)
......
......@@ -25,6 +25,7 @@ import json
import datetime
import time
from subprocess import call, check_output
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
from .config_utils import Config, Experiments
......@@ -264,7 +265,7 @@ def trial_kill(args):
return
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.id), REST_TIME_OUT)
response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
if response and check_response(response):
print(response.text)
else:
......@@ -272,6 +273,17 @@ def trial_kill(args):
else:
print_error('Restful server is not running...')
def trial_codegen(args):
    '''Generate code for a specific trial'''
    print_warning('Currently, this command is only for nni nas programming interface.')
    experiment_id = check_experiment_id(args)
    config = Config(get_config_filename(args))
    # Annotation expansion only makes sense for experiments launched with
    # useAnnotation enabled; bail out otherwise.
    if not config.get_config('experimentConfig')['useAnnotation']:
        print_error('The experiment is not using annotation')
        exit(1)
    source_dir = config.get_config('experimentConfig')['trial']['codeDir']
    target_dir = './exp_%s_trial_%s_code'%(experiment_id, args.trial_id)
    expand_annotations(source_dir, target_dir, experiment_id, args.trial_id)
def list_experiment(args):
'''Get experiment information'''
nni_config = Config(get_config_filename(args))
......@@ -309,7 +321,7 @@ def log_internal(args, filetype):
else:
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'stderr')
print(check_output_command(file_full_path, head=args.head, tail=args.tail))
def log_stdout(args):
    '''Show the stdout log of an experiment.'''
    # Delegate to the shared log reader, selecting the stdout file type.
    log_internal(args, 'stdout')
......@@ -381,7 +393,7 @@ def experiment_list(args):
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
experiment_information = ""
for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
......
......@@ -36,8 +36,8 @@ def process_install(package_name):
def package_install(args):
    '''Install the requirements of the package named on the command line.'''
    # All installation logic lives in process_install; this is the CLI entry.
    process_install(args.name)
def package_show(args):
    '''Print the names of all installable packages, space separated.'''
    # Iterating a dict yields its keys, so .keys() is unnecessary.
    print(' '.join(PACKAGE_REQUIREMENTS))
......@@ -112,7 +112,7 @@ def update_concurrency(args):
print_error('Update %s failed!' % 'concurrency')
def update_duration(args):
#parse time, change time unit to seconds
#parse time, change time unit to seconds
args.value = parse_time(args.value)
args.port = get_experiment_port(args)
if args.port is not None:
......
......@@ -40,16 +40,16 @@ def copyHdfsDirectoryToLocal(hdfsDirectory, localDirectory, hdfsClient):
copyHdfsDirectoryToLocal(subHdfsDirectory, subLocalDirectory, hdfsClient)
elif f.type == 'FILE':
hdfsFilePath = posixpath.join(hdfsDirectory, f.pathSuffix)
localFilePath = os.path.join(localDirectory, f.pathSuffix)
localFilePath = os.path.join(localDirectory, f.pathSuffix)
copyHdfsFileToLocal(hdfsFilePath, localFilePath, hdfsClient)
else:
else:
raise AssertionError('unexpected type {}'.format(f.type))
def copyHdfsFileToLocal(hdfsFilePath, localFilePath, hdfsClient, override=True):
'''Copy file from HDFS to local'''
if not hdfsClient.exists(hdfsFilePath):
raise Exception('HDFS file {} does not exist!'.format(hdfsFilePath))
try:
try:
file_status = hdfsClient.get_file_status(hdfsFilePath)
if file_status.type != 'FILE':
raise Exception('HDFS file path {} is not a file'.format(hdfsFilePath))
......
......@@ -142,7 +142,7 @@ class PipeLogReader(threading.Thread):
'''
time.sleep(5)
while True:
cur_process_exit = self.process_exit
cur_process_exit = self.process_exit
try:
line = self.queue.get(True, 5)
try:
......@@ -150,7 +150,7 @@ class PipeLogReader(threading.Thread):
except Exception as e:
pass
except Exception as e:
if cur_process_exit == True:
if cur_process_exit == True:
self._is_read_completed = True
break
......@@ -177,7 +177,7 @@ class PipeLogReader(threading.Thread):
if not self.log_pattern.match(line):
continue
self.queue.put(line)
self.pipeReader.close()
def close(self):
......@@ -190,7 +190,7 @@ class PipeLogReader(threading.Thread):
"""Return if read is completed
"""
return self._is_read_completed
def set_process_exit(self):
    '''Mark the monitored process as exited and return the (True) flag.'''
    self.process_exit = True
    # The flag was just set, so the return value is always True.
    return True
\ No newline at end of file
......@@ -39,9 +39,9 @@ class HDFSClientUtilityTest(unittest.TestCase):
self.hdfs_config = json.load(file)
except Exception as exception:
print(exception)
self.hdfs_client = HdfsClient(hosts='{0}:{1}'.format(self.hdfs_config['host'], '50070'), user_name=self.hdfs_config['userName'])
def get_random_name(self, length):
    '''Return a random name of the given length drawn from letters and digits.

    random.sample picks without replacement, so all characters are distinct.
    '''
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.sample(alphabet, length))
......@@ -49,20 +49,20 @@ class HDFSClientUtilityTest(unittest.TestCase):
'''test copyFileToHdfs'''
file_name = self.get_random_name(8)
file_content = 'hello world!'
with open('./{}'.format(file_name), 'w') as file:
file.write(file_content)
file.write(file_content)
result = copyFileToHdfs('./{}'.format(file_name), '/{0}/{1}'.format(self.hdfs_config['userName'], file_name), self.hdfs_client)
self.assertTrue(result)
file_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName']))
self.assertIn(file_name, file_list)
hdfs_file_name = self.get_random_name(8)
self.hdfs_client.copy_to_local('/{0}/{1}'.format(self.hdfs_config['userName'], file_name), './{}'.format(hdfs_file_name))
self.assertTrue(os.path.exists('./{}'.format(hdfs_file_name)))
with open('./{}'.format(hdfs_file_name), 'r') as file:
content = file.readline()
self.assertEqual(file_content, content)
......@@ -70,21 +70,21 @@ class HDFSClientUtilityTest(unittest.TestCase):
os.remove('./{}'.format(file_name))
os.remove('./{}'.format(hdfs_file_name))
self.hdfs_client.delete('/{0}/{1}'.format(self.hdfs_config['userName'], file_name))
def test_copy_directory_run(self):
'''test copyDirectoryToHdfs'''
directory_name = self.get_random_name(8)
file_name_list = [self.get_random_name(8), self.get_random_name(8)]
file_content = 'hello world!'
os.makedirs('./{}'.format(directory_name))
for file_name in file_name_list:
with open('./{0}/{1}'.format(directory_name, file_name), 'w') as file:
file.write(file_content)
result = copyDirectoryToHdfs('./{}'.format(directory_name), '/{0}/{1}'.format(self.hdfs_config['userName'], directory_name), self.hdfs_client)
self.assertTrue(result)
directory_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName']))
self.assertIn(directory_name, directory_list)
......@@ -94,7 +94,7 @@ class HDFSClientUtilityTest(unittest.TestCase):
#clean up
self.hdfs_client.delete('/{0}/{1}/{2}'.format(self.hdfs_config['userName'], directory_name, file_name))
self.hdfs_client.delete('/{0}/{1}'.format(self.hdfs_config['userName'], directory_name))
shutil.rmtree('./{}'.format(directory_name))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment