Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bf7daa8f
Unverified
Commit
bf7daa8f
authored
May 11, 2020
by
Yuge Zhang
Committed by
GitHub
May 11, 2020
Browse files
Prettify the export format of NAS trainer (#2389)
parent
af800213
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
29 deletions
+135
-29
docs/en_US/NAS/NasGuide.md
docs/en_US/NAS/NasGuide.md
+14
-3
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+33
-18
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+77
-8
src/sdk/pynni/nni/nas/pytorch/utils.py
src/sdk/pynni/nni/nas/pytorch/utils.py
+11
-0
No files found.
docs/en_US/NAS/NasGuide.md
View file @
bf7daa8f
...
...
@@ -156,12 +156,23 @@ model = Net()
apply_fixed_architecture
(
model
,
"model_dir/final_architecture.json"
)
```
The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in:
*
A string: select the candidate with corresponding name.
*
A number: select the candidate with corresponding index.
*
A list of string: select the candidates with corresponding names.
*
A list of number: select the candidates with corresponding indices.
*
A list of boolean values: a multi-hot array.
For example,
```
json
{
"LayerChoice1"
:
[
false
,
true
,
false
,
false
],
"InputChoice2"
:
[
true
,
true
,
false
]
"LayerChoice1"
:
"conv5x5"
,
"LayerChoice2"
:
6
,
"InputChoice3"
:
[
"layer1"
,
"layer3"
],
"InputChoice4"
:
[
1
,
2
],
"InputChoice5"
:
[
false
,
true
,
false
,
false
,
true
]
}
```
...
...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
bf7daa8f
...
...
@@ -3,10 +3,9 @@
import
json
import
torch
from
nni.nas.pytorch.mutables
import
MutableScope
from
nni.nas.pytorch.mutator
import
Mutator
from
.mutables
import
InputChoice
,
LayerChoice
,
MutableScope
from
.mutator
import
Mutator
from
.utils
import
to_list
class
FixedArchitecture
(
Mutator
):
...
...
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
----------
model : nn.Module
A mutable network.
fixed_arc :
str or
dict
P
ath to the architecture checkpoint (a string), or p
reloaded architecture object
(a dict)
.
fixed_arc : dict
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
...
...
@@ -33,6 +32,34 @@ class FixedArchitecture(Mutator):
raise
RuntimeError
(
"Unexpected keys found in fixed architecture: {}."
.
format
(
fixed_arc_keys
-
mutable_keys
))
if
mutable_keys
-
fixed_arc_keys
:
raise
RuntimeError
(
"Missing keys in fixed architecture: {}."
.
format
(
mutable_keys
-
fixed_arc_keys
))
self
.
_fixed_arc
=
self
.
_from_human_readable_architecture
(
self
.
_fixed_arc
)
def
_from_human_readable_architecture
(
self
,
human_arc
):
# convert from an exported architecture
result_arc
=
{
k
:
to_list
(
v
)
for
k
,
v
in
human_arc
.
items
()}
# there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} ir {"op1": ["conv", ]}
result_arc
=
{
k
:
v
if
isinstance
(
v
,
list
)
else
[
v
]
for
k
,
v
in
result_arc
.
items
()}
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
for
mutable
in
self
.
mutables
:
if
mutable
.
key
not
in
result_arc
:
continue
# skip silently
choice_arr
=
result_arc
[
mutable
.
key
]
if
all
(
isinstance
(
v
,
bool
)
for
v
in
choice_arr
)
or
all
(
isinstance
(
v
,
float
)
for
v
in
choice_arr
):
if
(
isinstance
(
mutable
,
LayerChoice
)
and
len
(
mutable
)
==
len
(
choice_arr
))
or
\
(
isinstance
(
mutable
,
InputChoice
)
and
mutable
.
n_candidates
==
len
(
choice_arr
)):
# multihot, do nothing
continue
if
isinstance
(
mutable
,
LayerChoice
):
choice_arr
=
[
mutable
.
names
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
len
(
mutable
))]
elif
isinstance
(
mutable
,
InputChoice
):
choice_arr
=
[
mutable
.
choose_from
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
mutable
.
n_candidates
)]
result_arc
[
mutable
.
key
]
=
choice_arr
return
result_arc
def
sample_search
(
self
):
"""
...
...
@@ -47,17 +74,6 @@ class FixedArchitecture(Mutator):
return
self
.
_fixed_arc
def
_encode_tensor
(
data
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
return
{
k
:
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
def
apply_fixed_architecture
(
model
,
fixed_arc
):
"""
Load architecture from `fixed_arc` and apply to model.
...
...
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
_encode_tensor
(
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
.
reset
()
return
architecture
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
bf7daa8f
...
...
@@ -7,7 +7,9 @@ from collections import defaultdict
import
numpy
as
np
import
torch
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
.base_mutator
import
BaseMutator
from
.mutables
import
LayerChoice
,
InputChoice
from
.utils
import
to_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -58,7 +60,16 @@ class Mutator(BaseMutator):
dict
A mapping from key of mutables to decisions.
"""
return
self
.
sample_final
()
sampled
=
self
.
sample_final
()
result
=
dict
()
for
mutable
in
self
.
mutables
:
if
not
isinstance
(
mutable
,
(
LayerChoice
,
InputChoice
)):
# not supported as built-in
continue
result
[
mutable
.
key
]
=
self
.
_convert_mutable_decision_to_human_readable
(
mutable
,
sampled
.
pop
(
mutable
.
key
))
if
sampled
:
raise
ValueError
(
"Unexpected keys returned from 'sample_final()': %s"
,
list
(
sampled
.
keys
()))
return
result
def
status
(
self
):
"""
...
...
@@ -159,7 +170,7 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
...
...
@@ -185,17 +196,41 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
if
"BoolTensor"
in
mask
.
type
():
"""
Select masked tensors and return a list of tensors.
Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, an numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.
Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
bool
))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
==
np
.
bool
)
or
\
"BoolTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
elif
"FloatTensor"
in
mask
.
type
():
elif
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
(
float
,
int
)))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
in
(
np
.
float32
,
np
.
float64
,
np
.
int32
,
np
.
int64
))
or
\
"FloatTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
else
:
raise
ValueError
(
"Unrecognized mask"
)
return
out
raise
ValueError
(
"Unrecognized mask '%s'"
%
mask
)
if
not
torch
.
is_tensor
(
mask
):
mask
=
torch
.
tensor
(
mask
)
# pylint: disable=not-callable
return
out
,
mask
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
...
...
@@ -237,3 +272,37 @@ class Mutator(BaseMutator):
result
=
self
.
_cache
[
mutable
.
key
]
logger
.
debug
(
"Decision %s: %s"
,
mutable
.
key
,
result
)
return
result
def
_convert_mutable_decision_to_human_readable
(
self
,
mutable
,
sampled
):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list
=
to_list
(
sampled
)
converted
=
None
# If it's a boolean array, we can do optimization.
if
all
([
t
==
0
or
t
==
1
for
t
in
multihot_list
]):
if
isinstance
(
mutable
,
LayerChoice
):
assert
len
(
multihot_list
)
==
len
(
mutable
),
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all modules have different names and they indeed have names
if
len
(
set
(
mutable
.
names
))
==
len
(
mutable
)
and
not
all
(
d
.
isdigit
()
for
d
in
mutable
.
names
):
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
names
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
isinstance
(
mutable
,
InputChoice
):
assert
len
(
multihot_list
)
==
mutable
.
n_candidates
,
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all input candidates have different names
if
len
(
set
(
mutable
.
choose_from
))
==
mutable
.
n_candidates
:
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
choose_from
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
converted
is
not
None
:
# if only one element, then remove the bracket
if
len
(
converted
)
==
1
:
converted
=
converted
[
0
]
else
:
# do nothing
converted
=
multihot_list
return
converted
src/sdk/pynni/nni/nas/pytorch/utils.py
View file @
bf7daa8f
...
...
@@ -4,6 +4,7 @@
import
logging
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
_counter
=
0
...
...
@@ -45,6 +46,16 @@ def to_device(obj, device):
raise
ValueError
(
"'%s' has unsupported type '%s'"
%
(
obj
,
type
(
obj
)))
def
to_list
(
arr
):
if
torch
.
is_tensor
(
arr
):
return
arr
.
cpu
().
numpy
().
tolist
()
if
isinstance
(
arr
,
np
.
ndarray
):
return
arr
.
tolist
()
if
isinstance
(
arr
,
(
list
,
tuple
)):
return
list
(
arr
)
return
arr
class
AverageMeterGroup
:
"""
Average meter group for multiple average meters.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment