Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7108466c
Commit
7108466c
authored
Mar 08, 2019
by
Zejun Lin
Committed by
QuanluZhang
Mar 07, 2019
Browse files
fix annotation key-error (#806)
* fix annotation, resolve annotation's key err bug, refactor the design
parent
f10c3311
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
67 deletions
+50
-67
src/sdk/pynni/nni/smartparam.py
src/sdk/pynni/nni/smartparam.py
+25
-37
src/sdk/pynni/tests/test_smartparam.py
src/sdk/pynni/tests/test_smartparam.py
+8
-23
tools/nni_annotation/__init__.py
tools/nni_annotation/__init__.py
+4
-1
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+13
-6
No files found.
src/sdk/pynni/nni/smartparam.py
View file @
7108466c
...
...
@@ -82,52 +82,40 @@ if env_args.platform is None:
else
:
def
choice
(
options
,
name
=
None
):
return
options
[
_get_param
(
'choice'
,
name
)]
def
choice
(
options
,
name
=
None
,
key
=
None
):
return
options
[
_get_param
(
key
)]
def
randint
(
upper
,
name
=
None
):
return
_get_param
(
'randint'
,
name
)
def
randint
(
upper
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
uniform
(
low
,
high
,
name
=
None
):
return
_get_param
(
'uniform'
,
name
)
def
uniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
quniform
(
low
,
high
,
q
,
name
=
None
):
return
_get_param
(
'quniform'
,
name
)
def
quniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
loguniform
(
low
,
high
,
name
=
None
):
return
_get_param
(
'loguniform'
,
name
)
def
loguniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qloguniform
(
low
,
high
,
q
,
name
=
None
):
return
_get_param
(
'qloguniform'
,
name
)
def
qloguniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
normal
(
mu
,
sigma
,
name
=
None
):
return
_get_param
(
'normal'
,
name
)
def
normal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
):
return
_get_param
(
'qnormal'
,
name
)
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
lognormal
(
mu
,
sigma
,
name
=
None
):
return
_get_param
(
'lognormal'
,
name
)
def
lognormal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
):
return
_get_param
(
'qlognormal'
,
name
)
def
function_choice
(
funcs
,
name
=
None
):
return
funcs
[
_get_param
(
'function_choice'
,
name
)]()
def
_get_param
(
func
,
name
):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame
=
inspect
.
stack
(
0
)[
2
]
filename
=
frame
.
filename
lineno
=
frame
.
lineno
# NOTE: this is the lineno of caller's last argument
del
frame
# see official doc
module
=
inspect
.
getmodulename
(
filename
)
if
name
is
None
:
name
=
'__line{:d}'
.
format
(
lineno
)
key
=
'{}/{}/{}'
.
format
(
module
,
name
,
func
)
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
key
)
def
function_choice
(
funcs
,
name
=
None
,
key
=
None
):
return
funcs
[
_get_param
(
key
)]()
def
_get_param
(
key
):
if
trial
.
_params
is
None
:
trial
.
get_next_parameter
()
return
trial
.
get_current_parameter
(
key
)
src/sdk/pynni/tests/test_smartparam.py
View file @
7108466c
...
...
@@ -29,8 +29,6 @@ import nni.trial
from
unittest
import
TestCase
,
main
lineno1
=
61
lineno2
=
75
class
SmartParamTestCase
(
TestCase
):
def
setUp
(
self
):
...
...
@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice2/choice'
:
'3*2+1'
,
'test_smartparam/choice3/choice'
:
'[1, 2]'
,
'test_smartparam/choice4/choice'
:
'{"a", 2}'
,
'test_smartparam/__line{:d}/uniform'
.
format
(
lineno1
):
'5'
,
'test_smartparam/func/function_choice'
:
'bar'
,
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
,
'test_smartparam/__line{:d}/function_choice'
.
format
(
lineno2
):
'max(1, 2, 3)'
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
}
nni
.
trial
.
_params
=
{
'parameter_id'
:
'test_trial'
,
'parameters'
:
params
}
def
test_specified_name
(
self
):
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
,
key
=
'test_smartparam/choice1/choice'
)
self
.
assertEqual
(
val
,
'a'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
,
key
=
'test_smartparam/choice2/choice'
)
self
.
assertEqual
(
val
,
7
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
,
key
=
'test_smartparam/choice3/choice'
)
self
.
assertEqual
(
val
,
[
1
,
2
])
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
,
key
=
'test_smartparam/choice4/choice'
)
self
.
assertEqual
(
val
,
{
"a"
,
2
})
def
test_default_name
(
self
):
val
=
nni
.
uniform
(
1
,
10
)
# NOTE: assign this line number to lineno1
self
.
assertEqual
(
val
,
'5'
)
def
test_specified_name_func
(
self
):
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
)
def
test_func
(
self
):
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
,
key
=
'test_smartparam/func/function_choice'
)
self
.
assertEqual
(
val
,
'bar'
)
def
test_lambda_func
(
self
):
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
)
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
,
key
=
'test_smartparam/lambda_func/function_choice'
)
self
.
assertEqual
(
val
,
6
)
def
test_default_name_func
(
self
):
val
=
nni
.
function_choice
({
'max(1, 2, 3)'
:
lambda
:
max
(
1
,
2
,
3
),
'min(1, 2)'
:
lambda
:
min
(
1
,
2
)
# NOTE: assign this line number to lineno2
})
self
.
assertEqual
(
val
,
3
)
def
foo
():
return
'foo'
...
...
tools/nni_annotation/__init__.py
View file @
7108466c
...
...
@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def
_generate_file_search_space
(
path
,
module
):
with
open
(
path
)
as
src
:
try
:
return
search_space_generator
.
generate
(
module
,
src
.
read
())
search_space
,
code
=
search_space_generator
.
generate
(
module
,
src
.
read
())
except
Exception
as
exc
:
# pylint: disable=broad-except
if
exc
.
args
:
raise
RuntimeError
(
path
+
' '
+
'
\n
'
.
join
(
exc
.
args
))
else
:
raise
RuntimeError
(
'Failed to generate search space for %s: %r'
%
(
path
,
exc
))
with
open
(
path
,
'w'
)
as
dst
:
dst
.
write
(
code
)
return
search_space
def
expand_annotations
(
src_dir
,
dst_dir
):
...
...
tools/nni_annotation/search_space_generator.py
View file @
7108466c
...
...
@@ -20,6 +20,7 @@
import
ast
import
astor
# pylint: disable=unidiomatic-typecheck
...
...
@@ -40,7 +41,7 @@ _ss_funcs = [
]
class
SearchSpaceGenerator
(
ast
.
Node
Visito
r
):
class
SearchSpaceGenerator
(
ast
.
Node
Transforme
r
):
"""Generate search space from smart parater APIs"""
def
__init__
(
self
,
module_name
):
...
...
@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# ignore if the function is not 'nni.*'
if
type
(
node
.
func
)
is
not
ast
.
Attribute
:
return
return
node
if
type
(
node
.
func
.
value
)
is
not
ast
.
Name
:
return
return
node
if
node
.
func
.
value
.
id
!=
'nni'
:
return
return
node
# ignore if its not a search space function (e.g. `report_final_result`)
func
=
node
.
func
.
attr
if
func
not
in
_ss_funcs
:
return
return
node
self
.
last_line
=
node
.
lineno
...
...
@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# generate the missing name automatically
name
=
'__line'
+
str
(
str
(
node
.
args
[
-
1
].
lineno
))
specified_name
=
False
node
.
keywords
=
list
()
if
func
in
(
'choice'
,
'function_choice'
):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
...
...
@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
args
=
[
arg
.
n
for
arg
in
node
.
args
]
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
# store key in ast.Call
node
.
keywords
.
append
(
ast
.
keyword
(
arg
=
'key'
,
value
=
ast
.
Str
(
s
=
key
)))
if
func
==
'function_choice'
:
func
=
'choice'
value
=
{
'_type'
:
func
,
'_value'
:
args
}
...
...
@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
self
.
search_space
[
key
]
=
value
return
node
def
generate
(
module_name
,
code
):
"""Generate search space.
...
...
@@ -120,4 +127,4 @@ def generate(module_name, code):
visitor
.
visit
(
ast_tree
)
except
AssertionError
as
exc
:
raise
RuntimeError
(
'%d: %s'
%
(
visitor
.
last_line
,
exc
.
args
[
0
]))
return
visitor
.
search_space
return
visitor
.
search_space
,
astor
.
to_source
(
ast_tree
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment