Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7108466c
Commit
7108466c
authored
Mar 08, 2019
by
Zejun Lin
Committed by
QuanluZhang
Mar 07, 2019
Browse files
fix annotation key-error (#806)
* fix annotation, resolve annotation's key err bug, refactor the design
parent
f10c3311
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
67 deletions
+50
-67
src/sdk/pynni/nni/smartparam.py
src/sdk/pynni/nni/smartparam.py
+25
-37
src/sdk/pynni/tests/test_smartparam.py
src/sdk/pynni/tests/test_smartparam.py
+8
-23
tools/nni_annotation/__init__.py
tools/nni_annotation/__init__.py
+4
-1
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+13
-6
No files found.
src/sdk/pynni/nni/smartparam.py
View file @
7108466c
...
@@ -82,52 +82,40 @@ if env_args.platform is None:
...
@@ -82,52 +82,40 @@ if env_args.platform is None:
else
:
else
:
def
choice
(
options
,
name
=
None
):
def
choice
(
options
,
name
=
None
,
key
=
None
):
return
options
[
_get_param
(
'choice'
,
name
)]
return
options
[
_get_param
(
key
)]
def
randint
(
upper
,
name
=
None
):
def
randint
(
upper
,
name
=
None
,
key
=
None
):
return
_get_param
(
'randint'
,
name
)
return
_get_param
(
key
)
def
uniform
(
low
,
high
,
name
=
None
):
def
uniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
'uniform'
,
name
)
return
_get_param
(
key
)
def
quniform
(
low
,
high
,
q
,
name
=
None
):
def
quniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
'quniform'
,
name
)
return
_get_param
(
key
)
def
loguniform
(
low
,
high
,
name
=
None
):
def
loguniform
(
low
,
high
,
name
=
None
,
key
=
None
):
return
_get_param
(
'loguniform'
,
name
)
return
_get_param
(
key
)
def
qloguniform
(
low
,
high
,
q
,
name
=
None
):
def
qloguniform
(
low
,
high
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
'qloguniform'
,
name
)
return
_get_param
(
key
)
def
normal
(
mu
,
sigma
,
name
=
None
):
def
normal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
'normal'
,
name
)
return
_get_param
(
key
)
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
):
def
qnormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
'qnormal'
,
name
)
return
_get_param
(
key
)
def
lognormal
(
mu
,
sigma
,
name
=
None
):
def
lognormal
(
mu
,
sigma
,
name
=
None
,
key
=
None
):
return
_get_param
(
'lognormal'
,
name
)
return
_get_param
(
key
)
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
):
def
qlognormal
(
mu
,
sigma
,
q
,
name
=
None
,
key
=
None
):
return
_get_param
(
'qlognormal'
,
name
)
return
_get_param
(
key
)
def
function_choice
(
funcs
,
name
=
None
):
def
function_choice
(
funcs
,
name
=
None
,
key
=
None
):
return
funcs
[
_get_param
(
'function_choice'
,
name
)]()
return
funcs
[
_get_param
(
key
)]()
def
_get_param
(
func
,
name
):
def
_get_param
(
key
):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame
=
inspect
.
stack
(
0
)[
2
]
filename
=
frame
.
filename
lineno
=
frame
.
lineno
# NOTE: this is the lineno of caller's last argument
del
frame
# see official doc
module
=
inspect
.
getmodulename
(
filename
)
if
name
is
None
:
name
=
'__line{:d}'
.
format
(
lineno
)
key
=
'{}/{}/{}'
.
format
(
module
,
name
,
func
)
if
trial
.
_params
is
None
:
if
trial
.
_params
is
None
:
trial
.
get_next_parameter
()
trial
.
get_next_parameter
()
return
trial
.
get_current_parameter
(
key
)
return
trial
.
get_current_parameter
(
key
)
src/sdk/pynni/tests/test_smartparam.py
View file @
7108466c
...
@@ -29,8 +29,6 @@ import nni.trial
...
@@ -29,8 +29,6 @@ import nni.trial
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
lineno1
=
61
lineno2
=
75
class
SmartParamTestCase
(
TestCase
):
class
SmartParamTestCase
(
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
...
@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice2/choice'
:
'3*2+1'
,
'test_smartparam/choice2/choice'
:
'3*2+1'
,
'test_smartparam/choice3/choice'
:
'[1, 2]'
,
'test_smartparam/choice3/choice'
:
'[1, 2]'
,
'test_smartparam/choice4/choice'
:
'{"a", 2}'
,
'test_smartparam/choice4/choice'
:
'{"a", 2}'
,
'test_smartparam/__line{:d}/uniform'
.
format
(
lineno1
):
'5'
,
'test_smartparam/func/function_choice'
:
'bar'
,
'test_smartparam/func/function_choice'
:
'bar'
,
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
,
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
'test_smartparam/__line{:d}/function_choice'
.
format
(
lineno2
):
'max(1, 2, 3)'
}
}
nni
.
trial
.
_params
=
{
'parameter_id'
:
'test_trial'
,
'parameters'
:
params
}
nni
.
trial
.
_params
=
{
'parameter_id'
:
'test_trial'
,
'parameters'
:
params
}
def
test_specified_name
(
self
):
def
test_specified_name
(
self
):
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice1'
,
key
=
'test_smartparam/choice1/choice'
)
self
.
assertEqual
(
val
,
'a'
)
self
.
assertEqual
(
val
,
'a'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice2'
,
key
=
'test_smartparam/choice2/choice'
)
self
.
assertEqual
(
val
,
7
)
self
.
assertEqual
(
val
,
7
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice3'
,
key
=
'test_smartparam/choice3/choice'
)
self
.
assertEqual
(
val
,
[
1
,
2
])
self
.
assertEqual
(
val
,
[
1
,
2
])
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
)
val
=
nni
.
choice
({
'a'
:
'a'
,
'3*2+1'
:
3
*
2
+
1
,
'[1, 2]'
:
[
1
,
2
],
'{"a", 2}'
:
{
"a"
,
2
}},
name
=
'choice4'
,
key
=
'test_smartparam/choice4/choice'
)
self
.
assertEqual
(
val
,
{
"a"
,
2
})
self
.
assertEqual
(
val
,
{
"a"
,
2
})
def
test_default_name
(
self
):
def
test_func
(
self
):
val
=
nni
.
uniform
(
1
,
10
)
# NOTE: assign this line number to lineno1
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
,
key
=
'test_smartparam/func/function_choice'
)
self
.
assertEqual
(
val
,
'5'
)
def
test_specified_name_func
(
self
):
val
=
nni
.
function_choice
({
'foo'
:
foo
,
'bar'
:
bar
},
name
=
'func'
)
self
.
assertEqual
(
val
,
'bar'
)
self
.
assertEqual
(
val
,
'bar'
)
def
test_lambda_func
(
self
):
def
test_lambda_func
(
self
):
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
)
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
,
key
=
'test_smartparam/lambda_func/function_choice'
)
self
.
assertEqual
(
val
,
6
)
self
.
assertEqual
(
val
,
6
)
def
test_default_name_func
(
self
):
val
=
nni
.
function_choice
({
'max(1, 2, 3)'
:
lambda
:
max
(
1
,
2
,
3
),
'min(1, 2)'
:
lambda
:
min
(
1
,
2
)
# NOTE: assign this line number to lineno2
})
self
.
assertEqual
(
val
,
3
)
def
foo
():
def
foo
():
return
'foo'
return
'foo'
...
...
tools/nni_annotation/__init__.py
View file @
7108466c
...
@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
...
@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def
_generate_file_search_space
(
path
,
module
):
def
_generate_file_search_space
(
path
,
module
):
with
open
(
path
)
as
src
:
with
open
(
path
)
as
src
:
try
:
try
:
return
search_space_generator
.
generate
(
module
,
src
.
read
())
search_space
,
code
=
search_space_generator
.
generate
(
module
,
src
.
read
())
except
Exception
as
exc
:
# pylint: disable=broad-except
except
Exception
as
exc
:
# pylint: disable=broad-except
if
exc
.
args
:
if
exc
.
args
:
raise
RuntimeError
(
path
+
' '
+
'
\n
'
.
join
(
exc
.
args
))
raise
RuntimeError
(
path
+
' '
+
'
\n
'
.
join
(
exc
.
args
))
else
:
else
:
raise
RuntimeError
(
'Failed to generate search space for %s: %r'
%
(
path
,
exc
))
raise
RuntimeError
(
'Failed to generate search space for %s: %r'
%
(
path
,
exc
))
with
open
(
path
,
'w'
)
as
dst
:
dst
.
write
(
code
)
return
search_space
def
expand_annotations
(
src_dir
,
dst_dir
):
def
expand_annotations
(
src_dir
,
dst_dir
):
...
...
tools/nni_annotation/search_space_generator.py
View file @
7108466c
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
import
ast
import
ast
import
astor
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
...
@@ -40,7 +41,7 @@ _ss_funcs = [
...
@@ -40,7 +41,7 @@ _ss_funcs = [
]
]
class
SearchSpaceGenerator
(
ast
.
Node
Visito
r
):
class
SearchSpaceGenerator
(
ast
.
Node
Transforme
r
):
"""Generate search space from smart parater APIs"""
"""Generate search space from smart parater APIs"""
def
__init__
(
self
,
module_name
):
def
__init__
(
self
,
module_name
):
...
@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
...
@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# ignore if the function is not 'nni.*'
# ignore if the function is not 'nni.*'
if
type
(
node
.
func
)
is
not
ast
.
Attribute
:
if
type
(
node
.
func
)
is
not
ast
.
Attribute
:
return
return
node
if
type
(
node
.
func
.
value
)
is
not
ast
.
Name
:
if
type
(
node
.
func
.
value
)
is
not
ast
.
Name
:
return
return
node
if
node
.
func
.
value
.
id
!=
'nni'
:
if
node
.
func
.
value
.
id
!=
'nni'
:
return
return
node
# ignore if its not a search space function (e.g. `report_final_result`)
# ignore if its not a search space function (e.g. `report_final_result`)
func
=
node
.
func
.
attr
func
=
node
.
func
.
attr
if
func
not
in
_ss_funcs
:
if
func
not
in
_ss_funcs
:
return
return
node
self
.
last_line
=
node
.
lineno
self
.
last_line
=
node
.
lineno
...
@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
...
@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# generate the missing name automatically
# generate the missing name automatically
name
=
'__line'
+
str
(
str
(
node
.
args
[
-
1
].
lineno
))
name
=
'__line'
+
str
(
str
(
node
.
args
[
-
1
].
lineno
))
specified_name
=
False
specified_name
=
False
node
.
keywords
=
list
()
if
func
in
(
'choice'
,
'function_choice'
):
if
func
in
(
'choice'
,
'function_choice'
):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
...
@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
...
@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
args
=
[
arg
.
n
for
arg
in
node
.
args
]
args
=
[
arg
.
n
for
arg
in
node
.
args
]
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
# store key in ast.Call
node
.
keywords
.
append
(
ast
.
keyword
(
arg
=
'key'
,
value
=
ast
.
Str
(
s
=
key
)))
if
func
==
'function_choice'
:
if
func
==
'function_choice'
:
func
=
'choice'
func
=
'choice'
value
=
{
'_type'
:
func
,
'_value'
:
args
}
value
=
{
'_type'
:
func
,
'_value'
:
args
}
...
@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
...
@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
self
.
search_space
[
key
]
=
value
self
.
search_space
[
key
]
=
value
return
node
def
generate
(
module_name
,
code
):
def
generate
(
module_name
,
code
):
"""Generate search space.
"""Generate search space.
...
@@ -120,4 +127,4 @@ def generate(module_name, code):
...
@@ -120,4 +127,4 @@ def generate(module_name, code):
visitor
.
visit
(
ast_tree
)
visitor
.
visit
(
ast_tree
)
except
AssertionError
as
exc
:
except
AssertionError
as
exc
:
raise
RuntimeError
(
'%d: %s'
%
(
visitor
.
last_line
,
exc
.
args
[
0
]))
raise
RuntimeError
(
'%d: %s'
%
(
visitor
.
last_line
,
exc
.
args
[
0
]))
return
visitor
.
search_space
return
visitor
.
search_space
,
astor
.
to_source
(
ast_tree
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment