Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9f40659d
"...resnet50_tensorflow.git" did not exist on "c3db5a9f3cb15b0742dbc4dbae09704f4e0565b8"
Commit
9f40659d
authored
Dec 27, 2019
by
Yuge Zhang
Committed by
QuanluZhang
Dec 27, 2019
Browse files
Fix a few issues related to fixed arc and from-tuner arc (#1876)
parent
db91e8e6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
23 deletions
+24
-23
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
+18
-12
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+6
-11
No files found.
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
View file @
9f40659d
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
import
nni
import
nni
from
nni.env_vars
import
trial_env_vars
from
nni.env_vars
import
trial_env_vars
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
,
MutableScope
from
nni.nas.pytorch.mutator
import
Mutator
from
nni.nas.pytorch.mutator
import
Mutator
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
...
@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
search_space_item : list
search_space_item : list
The list for corresponding search space.
The list for corresponding search space.
"""
"""
candidate_repr
=
search_space_item
[
"candidates"
]
multihot_list
=
[
False
]
*
mutable
.
n_candidates
multihot_list
=
[
False
]
*
mutable
.
n_candidates
for
i
,
v
in
zip
(
idx
,
value
):
for
i
,
v
in
zip
(
idx
,
value
):
assert
0
<=
i
<
mutable
.
n_candidates
and
search_space_item
[
i
]
==
v
,
\
assert
0
<=
i
<
mutable
.
n_candidates
and
candidate_repr
[
i
]
==
v
,
\
"Index '{}' in search space '{}' is not '{}'"
.
format
(
i
,
search_space_item
,
v
)
"Index '{}' in search space '{}' is not '{}'"
.
format
(
i
,
candidate_repr
,
v
)
assert
not
multihot_list
[
i
],
"'{}' is selected twice in '{}', which is not allowed."
.
format
(
i
,
idx
)
assert
not
multihot_list
[
i
],
"'{}' is selected twice in '{}', which is not allowed."
.
format
(
i
,
idx
)
multihot_list
[
i
]
=
True
multihot_list
[
i
]
=
True
return
torch
.
tensor
(
multihot_list
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
return
torch
.
tensor
(
multihot_list
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
...
@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
...
@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
self
.
_chosen_arch
.
keys
())
self
.
_chosen_arch
.
keys
())
result
=
dict
()
result
=
dict
()
for
mutable
in
self
.
mutables
:
for
mutable
in
self
.
mutables
:
assert
mutable
.
key
in
self
.
_chosen_arch
,
"Expected '{}' in chosen arch, but not found."
.
format
(
mutable
.
key
)
if
isinstance
(
mutable
,
(
LayerChoice
,
InputChoice
)):
data
=
self
.
_chosen_arch
[
mutable
.
key
]
assert
mutable
.
key
in
self
.
_chosen_arch
,
\
assert
isinstance
(
data
,
dict
)
and
"_value"
in
data
and
"_idx"
in
data
,
\
"Expected '{}' in chosen arch, but not found."
.
format
(
mutable
.
key
)
"'{}' is not a valid choice."
.
format
(
data
)
data
=
self
.
_chosen_arch
[
mutable
.
key
]
value
=
data
[
"_value"
]
assert
isinstance
(
data
,
dict
)
and
"_value"
in
data
and
"_idx"
in
data
,
\
idx
=
data
[
"_idx"
]
"'{}' is not a valid choice."
.
format
(
data
)
search_space_item
=
self
.
_search_space
[
mutable
.
key
][
"_value"
]
if
isinstance
(
mutable
,
LayerChoice
):
if
isinstance
(
mutable
,
LayerChoice
):
result
[
mutable
.
key
]
=
self
.
_sample_layer_choice
(
mutable
,
idx
,
value
,
search_space_item
)
result
[
mutable
.
key
]
=
self
.
_sample_layer_choice
(
mutable
,
data
[
"_idx"
],
data
[
"_value"
],
self
.
_search_space
[
mutable
.
key
][
"_value"
])
elif
isinstance
(
mutable
,
InputChoice
):
elif
isinstance
(
mutable
,
InputChoice
):
result
[
mutable
.
key
]
=
self
.
_sample_input_choice
(
mutable
,
idx
,
value
,
search_space_item
)
result
[
mutable
.
key
]
=
self
.
_sample_input_choice
(
mutable
,
data
[
"_idx"
],
data
[
"_value"
],
self
.
_search_space
[
mutable
.
key
][
"_value"
])
elif
isinstance
(
mutable
,
MutableScope
):
logger
.
info
(
"Mutable scope '%s' is skipped during parsing choices."
,
mutable
.
key
)
else
:
else
:
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
return
result
return
result
...
@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
...
@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
search_space
[
key
]
=
{
"_type"
:
INPUT_CHOICE
,
search_space
[
key
]
=
{
"_type"
:
INPUT_CHOICE
,
"_value"
:
{
"candidates"
:
mutable
.
choose_from
,
"_value"
:
{
"candidates"
:
mutable
.
choose_from
,
"n_chosen"
:
mutable
.
n_chosen
}}
"n_chosen"
:
mutable
.
n_chosen
}}
elif
isinstance
(
mutable
,
MutableScope
):
logger
.
info
(
"Mutable scope '%s' is skipped during generating search space."
,
mutable
.
key
)
else
:
else
:
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
return
search_space
return
search_space
...
...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
9f40659d
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
return
self
.
_fixed_arc
return
self
.
_fixed_arc
def
_encode_tensor
(
data
,
device
):
def
_encode_tensor
(
data
):
if
isinstance
(
data
,
list
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
,
device
=
device
)
# pylint: disable=not-callable
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
,
device
=
device
)
# pylint: disable=not-callable
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
return
{
k
:
_encode_tensor
(
v
,
device
)
for
k
,
v
in
data
.
items
()}
return
{
k
:
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
return
data
def
apply_fixed_architecture
(
model
,
fixed_arc_path
,
device
=
None
):
def
apply_fixed_architecture
(
model
,
fixed_arc_path
):
"""
"""
Load architecture from `fixed_arc_path` and apply to model.
Load architecture from `fixed_arc_path` and apply to model.
...
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
...
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
Model with mutables.
Model with mutables.
fixed_arc_path : str
fixed_arc_path : str
Path to the JSON that stores the architecture.
Path to the JSON that stores the architecture.
device : torch.device
Architecture weights will be transfered to `device`.
Returns
Returns
-------
-------
FixedArchitecture
FixedArchitecture
"""
"""
if
device
is
None
:
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
if
isinstance
(
fixed_arc_path
,
str
):
if
isinstance
(
fixed_arc_path
,
str
):
with
open
(
fixed_arc_path
,
"r"
)
as
f
:
with
open
(
fixed_arc_path
,
"r"
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
_encode_tensor
(
fixed_arc
,
device
)
fixed_arc
=
_encode_tensor
(
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
.
to
(
device
)
architecture
.
reset
()
architecture
.
reset
()
return
architecture
return
architecture
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment