Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7108466c
Commit
7108466c
authored
Mar 08, 2019
by
Zejun Lin
Committed by
QuanluZhang
Mar 07, 2019
Browse files
fix annotation key-error (#806)
* fix annotation, resolve annotation's key err bug, refactor the design
parent
f10c3311
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
67 deletions
+50
-67
src/sdk/pynni/nni/smartparam.py
src/sdk/pynni/nni/smartparam.py
+25
-37
src/sdk/pynni/tests/test_smartparam.py
src/sdk/pynni/tests/test_smartparam.py
+8
-23
tools/nni_annotation/__init__.py
tools/nni_annotation/__init__.py
+4
-1
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+13
-6
No files found.
src/sdk/pynni/nni/smartparam.py
View file @
7108466c
...
...
@@ -82,52 +82,40 @@ if env_args.platform is None:
else
:
def
choice
(
options
,
name
=
None
):
return
options
[
_get_param
(
'choice'
,
name
)]
def
choice
(
options
,
name
=
None
,
key
=
None
):
return
options
[
_get_param
(
key
)]
def
randint
(
upper
,
name
=
None
):
return
_get_param
(
'randint'
,
name
)
def
randint
(
upper
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
uniform
(
low
,
high
,
name
=
None
):
return
_get_param
(
'uniform'
,
name
)
def
uniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
quniform
(
low
,
high
,
q
,
name
=
None
):
return
_get_param
(
'quniform'
,
name
)
def
quniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
loguniform
(
low
,
high
,
name
=
None
):
return
_get_param
(
'loguniform'
,
name
)
def
loguniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qloguniform
(
low
,
high
,
q
,
name
=
None
):
return
_get_param
(
'qloguniform'
,
name
)
def
qloguniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
normal
(
mu
,
sigma
,
name
=
None
):
return
_get_param
(
'normal'
,
name
)
def
normal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
):
return
_get_param
(
'qnormal'
,
name
)
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
lognormal
(
mu
,
sigma
,
name
=
None
):
return
_get_param
(
'lognormal'
,
name
)
def
lognormal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
):
return
_get_param
(
'qlognormal'
,
name
)
def
function_choice
(
funcs
,
name
=
None
):
return
funcs
[
_get_param
(
'function_choice'
,
name
)]()
def
_get_param
(
func
,
name
):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame
=
inspect
.
stack
(
0
)[
2
]
filename
=
frame
.
filename
lineno
=
frame
.
lineno
# NOTE: this is the lineno of caller's last argument
del
frame
# see official doc
module
=
inspect
.
getmodulename
(
filename
)
if
name
is
None
:
name
=
'__line{:d}'
.
format
(
lineno
)
key
=
'{}/{}/{}'
.
format
(
module
,
name
,
func
)
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
function_choice
(
funcs
,
name
=
None
,
key
=
None
):
return
funcs
[
_get_param
(
key
)]()
def
_get_param
(
key
):
if
trial
.
_params
is
None
:
trial
.
get_next_parameter
()
return
trial
.
get_current_parameter
(
key
)
src/sdk/pynni/tests/test_smartparam.py
View file @
7108466c
...
...
@@ -29,8 +29,6 @@ import nni.trial
from
unittest
import
TestCase
,
main
lineno1
=
61
lineno2
=
75
class
SmartParamTestCase
(
TestCase
):
def
setUp
(
self
):
...
...
@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice2/choice'
:
'3*2+1'
,
'test_smartparam/choice3/choice'
:
'[1, 2]'
,
'test_smartparam/choice4/choice'
:
'{"a", 2}'
,
'test_smartparam/__line{:d}/uniform'
.
format
(
lineno1
):
'5'
,
'test_smartparam/func/function_choice'
:
'bar'
,
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
,
'test_smartparam/__line{:d}/function_choice'
.
format
(
lineno2
):
'max(1, 2, 3)'
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
}
nni
.
trial
.
_params
=
{
'parameter_id'
:
'test_trial'
,
'parameters'
:
params
}
def
test_specified_name
(
self
):
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
,
key
=
'test_smartparam/choice1/choice'
)
self
.
assertEqual
(
val
,
'a'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
,
key
=
'test_smartparam/choice2/choice'
)
self
.
assertEqual
(
val
,
7
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
,
key
=
'test_smartparam/choice3/choice'
)
self
.
assertEqual
(
val
,
[
1
,
2
])
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
,
key
=
'test_smartparam/choice4/choice'
)
self
.
assertEqual
(
val
,
{
"a"
,
2
})
def
test_default_name
(
self
):
val
=
nni
.
uniform
(
1
,
10
)
# NOTE: assign this line number to lineno1
self
.
assertEqual
(
val
,
'5'
)
def
test_specified_name_func
(
self
):
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
)
def
test_func
(
self
):
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
,
key
=
'test_smartparam/func/function_choice'
)
self
.
assertEqual
(
val
,
'bar'
)
def
test_lambda_func
(
self
):
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
)
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
,
key
=
'test_smartparam/lambda_func/function_choice'
)
self
.
assertEqual
(
val
,
6
)
def
test_default_name_func
(
self
):
val
=
nni
.
function_choice
({
'max(1, 2, 3)'
:
lambda
:
max
(
1
,
2
,
3
),
'min(1, 2)'
:
lambda
:
min
(
1
,
2
)
# NOTE: assign this line number to lineno2
})
self
.
assertEqual
(
val
,
3
)
def
foo
():
return
'foo'
...
...
tools/nni_annotation/__init__.py
View file @
7108466c
...
...
@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def
_generate_file_search_space
(
path
,
module
):
with
open
(
path
)
as
src
:
try
:
return
search_space_generator
.
generate
(
module
,
src
.
read
())
search_space
,
code
=
search_space_generator
.
generate
(
module
,
src
.
read
())
except
Exception
as
exc
:
# pylint: disable=broad-except
if
exc
.
args
:
raise
RuntimeError
(
path
+
' '
+
'
\n
'
.
join
(
exc
.
args
))
else
:
raise
RuntimeError
(
'Failed to generate search space for %s: %r'
%
(
path
,
exc
))
with
open
(
path
,
'w'
)
as
dst
:
dst
.
write
(
code
)
return
search_space
def
expand_annotations
(
src_dir
,
dst_dir
):
...
...
tools/nni_annotation/search_space_generator.py
View file @
7108466c
...
...
@@ -20,6 +20,7 @@
import
ast
import
astor
# pylint: disable=unidiomatic-typecheck
...
...
@@ -40,7 +41,7 @@ _ss_funcs = [
]
class
SearchSpaceGenerator
(
ast
.
Node
Visito
r
):
class
SearchSpaceGenerator
(
ast
.
Node
Transforme
r
):
"""Generate search space from smart parater APIs"""
def
__init__
(
self
,
module_name
):
...
...
@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# ignore if the function is not 'nni.*'
if
type
(
node
.
func
)
is
not
ast
.
Attribute
:
return
return
node
if
type
(
node
.
func
.
value
)
is
not
ast
.
Name
:
return
return
node
if
node
.
func
.
value
.
id
!=
'nni'
:
return
return
node
# ignore if its not a search space function (e.g. `report_final_result`)
func
=
node
.
func
.
attr
if
func
not
in
_ss_funcs
:
return
return
node
self
.
last_line
=
node
.
lineno
...
...
@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# generate the missing name automatically
name
=
'__line'
+
str
(
str
(
node
.
args
[
-
1
].
lineno
))
specified_name
=
False
node
.
keywords
=
list
()
if
func
in
(
'choice'
,
'function_choice'
):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
...
...
@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
args
=
[
arg
.
n
for
arg
in
node
.
args
]
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
# store key in ast.Call
node
.
keywords
.
append
(
ast
.
keyword
(
arg
=
'key'
,
value
=
ast
.
Str
(
s
=
key
)))
if
func
==
'function_choice'
:
func
=
'choice'
value
=
{
'_type'
:
func
,
'_value'
:
args
}
...
...
@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
self
.
search_space
[
key
]
=
value
return
node
def
generate
(
module_name
,
code
):
"""Generate search space.
...
...
@@ -120,4 +127,4 @@ def generate(module_name, code):
visitor
.
visit
(
ast_tree
)
except
AssertionError
as
exc
:
raise
RuntimeError
(
'%d: %s'
%
(
visitor
.
last_line
,
exc
.
args
[
0
]))
return
visitor
.
search_space
return
visitor
.
search_space
,
astor
.
to_source
(
ast_tree
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment