Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9f40659d
Commit
9f40659d
authored
Dec 27, 2019
by
Yuge Zhang
Committed by
QuanluZhang
Dec 27, 2019
Browse files
Fix a few issues related to fixed arc and from-tuner arc (#1876)
parent
db91e8e6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
23 deletions
+24
-23
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
+18
-12
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+6
-11
No files found.
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
View file @
9f40659d
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
import
nni
import
nni
from
nni.env_vars
import
trial_env_vars
from
nni.env_vars
import
trial_env_vars
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
,
MutableScope
from
nni.nas.pytorch.mutator
import
Mutator
from
nni.nas.pytorch.mutator
import
Mutator
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
...
@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
search_space_item : list
search_space_item : list
The list for corresponding search space.
The list for corresponding search space.
"""
"""
candidate_repr
=
search_space_item
[
"candidates"
]
multihot_list
=
[
False
]
*
mutable
.
n_candidates
multihot_list
=
[
False
]
*
mutable
.
n_candidates
for
i
,
v
in
zip
(
idx
,
value
):
for
i
,
v
in
zip
(
idx
,
value
):
assert
0
<=
i
<
mutable
.
n_candidates
and
search_space_item
[
i
]
==
v
,
\
assert
0
<=
i
<
mutable
.
n_candidates
and
candidate_repr
[
i
]
==
v
,
\
"Index '{}' in search space '{}' is not '{}'"
.
format
(
i
,
search_space_item
,
v
)
"Index '{}' in search space '{}' is not '{}'"
.
format
(
i
,
candidate_repr
,
v
)
assert
not
multihot_list
[
i
],
"'{}' is selected twice in '{}', which is not allowed."
.
format
(
i
,
idx
)
assert
not
multihot_list
[
i
],
"'{}' is selected twice in '{}', which is not allowed."
.
format
(
i
,
idx
)
multihot_list
[
i
]
=
True
multihot_list
[
i
]
=
True
return
torch
.
tensor
(
multihot_list
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
return
torch
.
tensor
(
multihot_list
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
...
@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
...
@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
self
.
_chosen_arch
.
keys
())
self
.
_chosen_arch
.
keys
())
result
=
dict
()
result
=
dict
()
for
mutable
in
self
.
mutables
:
for
mutable
in
self
.
mutables
:
assert
mutable
.
key
in
self
.
_chosen_arch
,
"Expected '{}' in chosen arch, but not found."
.
format
(
mutable
.
key
)
if
isinstance
(
mutable
,
(
LayerChoice
,
InputChoice
)):
assert
mutable
.
key
in
self
.
_chosen_arch
,
\
"Expected '{}' in chosen arch, but not found."
.
format
(
mutable
.
key
)
data
=
self
.
_chosen_arch
[
mutable
.
key
]
data
=
self
.
_chosen_arch
[
mutable
.
key
]
assert
isinstance
(
data
,
dict
)
and
"_value"
in
data
and
"_idx"
in
data
,
\
assert
isinstance
(
data
,
dict
)
and
"_value"
in
data
and
"_idx"
in
data
,
\
"'{}' is not a valid choice."
.
format
(
data
)
"'{}' is not a valid choice."
.
format
(
data
)
value
=
data
[
"_value"
]
idx
=
data
[
"_idx"
]
search_space_item
=
self
.
_search_space
[
mutable
.
key
][
"_value"
]
if
isinstance
(
mutable
,
LayerChoice
):
if
isinstance
(
mutable
,
LayerChoice
):
result
[
mutable
.
key
]
=
self
.
_sample_layer_choice
(
mutable
,
idx
,
value
,
search_space_item
)
result
[
mutable
.
key
]
=
self
.
_sample_layer_choice
(
mutable
,
data
[
"_idx"
],
data
[
"_value"
],
self
.
_search_space
[
mutable
.
key
][
"_value"
])
elif
isinstance
(
mutable
,
InputChoice
):
elif
isinstance
(
mutable
,
InputChoice
):
result
[
mutable
.
key
]
=
self
.
_sample_input_choice
(
mutable
,
idx
,
value
,
search_space_item
)
result
[
mutable
.
key
]
=
self
.
_sample_input_choice
(
mutable
,
data
[
"_idx"
],
data
[
"_value"
],
self
.
_search_space
[
mutable
.
key
][
"_value"
])
elif
isinstance
(
mutable
,
MutableScope
):
logger
.
info
(
"Mutable scope '%s' is skipped during parsing choices."
,
mutable
.
key
)
else
:
else
:
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
return
result
return
result
...
@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
...
@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
search_space
[
key
]
=
{
"_type"
:
INPUT_CHOICE
,
search_space
[
key
]
=
{
"_type"
:
INPUT_CHOICE
,
"_value"
:
{
"candidates"
:
mutable
.
choose_from
,
"_value"
:
{
"candidates"
:
mutable
.
choose_from
,
"n_chosen"
:
mutable
.
n_chosen
}}
"n_chosen"
:
mutable
.
n_chosen
}}
elif
isinstance
(
mutable
,
MutableScope
):
logger
.
info
(
"Mutable scope '%s' is skipped during generating search space."
,
mutable
.
key
)
else
:
else
:
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
raise
TypeError
(
"Unsupported mutable type: '%s'."
%
type
(
mutable
))
return
search_space
return
search_space
...
...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
9f40659d
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
return
self
.
_fixed_arc
return
self
.
_fixed_arc
def
_encode_tensor
(
data
,
device
):
def
_encode_tensor
(
data
):
if
isinstance
(
data
,
list
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
,
device
=
device
)
# pylint: disable=not-callable
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
,
device
=
device
)
# pylint: disable=not-callable
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
return
{
k
:
_encode_tensor
(
v
,
device
)
for
k
,
v
in
data
.
items
()}
return
{
k
:
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
return
data
def
apply_fixed_architecture
(
model
,
fixed_arc_path
,
device
=
None
):
def
apply_fixed_architecture
(
model
,
fixed_arc_path
):
"""
"""
Load architecture from `fixed_arc_path` and apply to model.
Load architecture from `fixed_arc_path` and apply to model.
...
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
...
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
Model with mutables.
Model with mutables.
fixed_arc_path : str
fixed_arc_path : str
Path to the JSON that stores the architecture.
Path to the JSON that stores the architecture.
device : torch.device
Architecture weights will be transfered to `device`.
Returns
Returns
-------
-------
FixedArchitecture
FixedArchitecture
"""
"""
if
device
is
None
:
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
if
isinstance
(
fixed_arc_path
,
str
):
if
isinstance
(
fixed_arc_path
,
str
):
with
open
(
fixed_arc_path
,
"r"
)
as
f
:
with
open
(
fixed_arc_path
,
"r"
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
_encode_tensor
(
fixed_arc
,
device
)
fixed_arc
=
_encode_tensor
(
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
.
to
(
device
)
architecture
.
reset
()
architecture
.
reset
()
return
architecture
return
architecture
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment