Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bf7daa8f
Unverified
Commit
bf7daa8f
authored
May 11, 2020
by
Yuge Zhang
Committed by
GitHub
May 11, 2020
Browse files
Prettify the export format of NAS trainer (#2389)
parent
af800213
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
29 deletions
+135
-29
docs/en_US/NAS/NasGuide.md
docs/en_US/NAS/NasGuide.md
+14
-3
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+33
-18
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+77
-8
src/sdk/pynni/nni/nas/pytorch/utils.py
src/sdk/pynni/nni/nas/pytorch/utils.py
+11
-0
No files found.
docs/en_US/NAS/NasGuide.md
View file @
bf7daa8f
...
@@ -156,12 +156,23 @@ model = Net()
...
@@ -156,12 +156,23 @@ model = Net()
apply_fixed_architecture
(
model
,
"model_dir/final_architecture.json"
)
apply_fixed_architecture
(
model
,
"model_dir/final_architecture.json"
)
```
```
The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in:
*
A string: select the candidate with corresponding name.
*
A number: select the candidate with corresponding index.
*
A list of string: select the candidates with corresponding names.
*
A list of number: select the candidates with corresponding indices.
*
A list of boolean values: a multi-hot array.
For example,
```
json
```
json
{
{
"LayerChoice1"
:
[
false
,
true
,
false
,
false
],
"LayerChoice1"
:
"conv5x5"
,
"InputChoice2"
:
[
true
,
true
,
false
]
"LayerChoice2"
:
6
,
"InputChoice3"
:
[
"layer1"
,
"layer3"
],
"InputChoice4"
:
[
1
,
2
],
"InputChoice5"
:
[
false
,
true
,
false
,
false
,
true
]
}
}
```
```
...
...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
bf7daa8f
...
@@ -3,10 +3,9 @@
...
@@ -3,10 +3,9 @@
import
json
import
json
import
torch
from
.mutables
import
InputChoice
,
LayerChoice
,
MutableScope
from
.mutator
import
Mutator
from
nni.nas.pytorch.mutables
import
MutableScope
from
.utils
import
to_list
from
nni.nas.pytorch.mutator
import
Mutator
class
FixedArchitecture
(
Mutator
):
class
FixedArchitecture
(
Mutator
):
...
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
...
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
----------
----------
model : nn.Module
model : nn.Module
A mutable network.
A mutable network.
fixed_arc :
str or
dict
fixed_arc : dict
P
ath to the architecture checkpoint (a string), or p
reloaded architecture object
(a dict)
.
Preloaded architecture object.
strict : bool
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
"""
...
@@ -33,6 +32,34 @@ class FixedArchitecture(Mutator):
...
@@ -33,6 +32,34 @@ class FixedArchitecture(Mutator):
raise
RuntimeError
(
"Unexpected keys found in fixed architecture: {}."
.
format
(
fixed_arc_keys
-
mutable_keys
))
raise
RuntimeError
(
"Unexpected keys found in fixed architecture: {}."
.
format
(
fixed_arc_keys
-
mutable_keys
))
if
mutable_keys
-
fixed_arc_keys
:
if
mutable_keys
-
fixed_arc_keys
:
raise
RuntimeError
(
"Missing keys in fixed architecture: {}."
.
format
(
mutable_keys
-
fixed_arc_keys
))
raise
RuntimeError
(
"Missing keys in fixed architecture: {}."
.
format
(
mutable_keys
-
fixed_arc_keys
))
self
.
_fixed_arc
=
self
.
_from_human_readable_architecture
(
self
.
_fixed_arc
)
def
_from_human_readable_architecture
(
self
,
human_arc
):
# convert from an exported architecture
result_arc
=
{
k
:
to_list
(
v
)
for
k
,
v
in
human_arc
.
items
()}
# there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} ir {"op1": ["conv", ]}
result_arc
=
{
k
:
v
if
isinstance
(
v
,
list
)
else
[
v
]
for
k
,
v
in
result_arc
.
items
()}
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
for
mutable
in
self
.
mutables
:
if
mutable
.
key
not
in
result_arc
:
continue
# skip silently
choice_arr
=
result_arc
[
mutable
.
key
]
if
all
(
isinstance
(
v
,
bool
)
for
v
in
choice_arr
)
or
all
(
isinstance
(
v
,
float
)
for
v
in
choice_arr
):
if
(
isinstance
(
mutable
,
LayerChoice
)
and
len
(
mutable
)
==
len
(
choice_arr
))
or
\
(
isinstance
(
mutable
,
InputChoice
)
and
mutable
.
n_candidates
==
len
(
choice_arr
)):
# multihot, do nothing
continue
if
isinstance
(
mutable
,
LayerChoice
):
choice_arr
=
[
mutable
.
names
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
len
(
mutable
))]
elif
isinstance
(
mutable
,
InputChoice
):
choice_arr
=
[
mutable
.
choose_from
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
mutable
.
n_candidates
)]
result_arc
[
mutable
.
key
]
=
choice_arr
return
result_arc
def
sample_search
(
self
):
def
sample_search
(
self
):
"""
"""
...
@@ -47,17 +74,6 @@ class FixedArchitecture(Mutator):
...
@@ -47,17 +74,6 @@ class FixedArchitecture(Mutator):
return
self
.
_fixed_arc
return
self
.
_fixed_arc
def
_encode_tensor
(
data
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
return
{
k
:
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
def
apply_fixed_architecture
(
model
,
fixed_arc
):
def
apply_fixed_architecture
(
model
,
fixed_arc
):
"""
"""
Load architecture from `fixed_arc` and apply to model.
Load architecture from `fixed_arc` and apply to model.
...
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
...
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
_encode_tensor
(
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
.
reset
()
architecture
.
reset
()
return
architecture
return
architecture
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
bf7daa8f
...
@@ -7,7 +7,9 @@ from collections import defaultdict
...
@@ -7,7 +7,9 @@ from collections import defaultdict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
.base_mutator
import
BaseMutator
from
.mutables
import
LayerChoice
,
InputChoice
from
.utils
import
to_list
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -58,7 +60,16 @@ class Mutator(BaseMutator):
...
@@ -58,7 +60,16 @@ class Mutator(BaseMutator):
dict
dict
A mapping from key of mutables to decisions.
A mapping from key of mutables to decisions.
"""
"""
return
self
.
sample_final
()
sampled
=
self
.
sample_final
()
result
=
dict
()
for
mutable
in
self
.
mutables
:
if
not
isinstance
(
mutable
,
(
LayerChoice
,
InputChoice
)):
# not supported as built-in
continue
result
[
mutable
.
key
]
=
self
.
_convert_mutable_decision_to_human_readable
(
mutable
,
sampled
.
pop
(
mutable
.
key
))
if
sampled
:
raise
ValueError
(
"Unexpected keys returned from 'sample_final()': %s"
,
list
(
sampled
.
keys
()))
return
result
def
status
(
self
):
def
status
(
self
):
"""
"""
...
@@ -159,7 +170,7 @@ class Mutator(BaseMutator):
...
@@ -159,7 +170,7 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
),
\
assert
len
(
mask
)
==
len
(
mutable
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
))
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
...
@@ -185,17 +196,41 @@ class Mutator(BaseMutator):
...
@@ -185,17 +196,41 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
if
"BoolTensor"
in
mask
.
type
():
"""
Select masked tensors and return a list of tensors.
Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, an numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.
Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
bool
))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
==
np
.
bool
)
or
\
"BoolTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
elif
"FloatTensor"
in
mask
.
type
():
elif
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
(
float
,
int
)))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
in
(
np
.
float32
,
np
.
float64
,
np
.
int32
,
np
.
int64
))
or
\
"FloatTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
else
:
else
:
raise
ValueError
(
"Unrecognized mask"
)
raise
ValueError
(
"Unrecognized mask '%s'"
%
mask
)
return
out
if
not
torch
.
is_tensor
(
mask
):
mask
=
torch
.
tensor
(
mask
)
# pylint: disable=not-callable
return
out
,
mask
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
if
reduction_type
==
"none"
:
...
@@ -237,3 +272,37 @@ class Mutator(BaseMutator):
...
@@ -237,3 +272,37 @@ class Mutator(BaseMutator):
result
=
self
.
_cache
[
mutable
.
key
]
result
=
self
.
_cache
[
mutable
.
key
]
logger
.
debug
(
"Decision %s: %s"
,
mutable
.
key
,
result
)
logger
.
debug
(
"Decision %s: %s"
,
mutable
.
key
,
result
)
return
result
return
result
def
_convert_mutable_decision_to_human_readable
(
self
,
mutable
,
sampled
):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list
=
to_list
(
sampled
)
converted
=
None
# If it's a boolean array, we can do optimization.
if
all
([
t
==
0
or
t
==
1
for
t
in
multihot_list
]):
if
isinstance
(
mutable
,
LayerChoice
):
assert
len
(
multihot_list
)
==
len
(
mutable
),
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all modules have different names and they indeed have names
if
len
(
set
(
mutable
.
names
))
==
len
(
mutable
)
and
not
all
(
d
.
isdigit
()
for
d
in
mutable
.
names
):
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
names
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
isinstance
(
mutable
,
InputChoice
):
assert
len
(
multihot_list
)
==
mutable
.
n_candidates
,
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all input candidates have different names
if
len
(
set
(
mutable
.
choose_from
))
==
mutable
.
n_candidates
:
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
choose_from
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
converted
is
not
None
:
# if only one element, then remove the bracket
if
len
(
converted
)
==
1
:
converted
=
converted
[
0
]
else
:
# do nothing
converted
=
multihot_list
return
converted
src/sdk/pynni/nni/nas/pytorch/utils.py
View file @
bf7daa8f
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
logging
import
logging
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
import
torch
_counter
=
0
_counter
=
0
...
@@ -45,6 +46,16 @@ def to_device(obj, device):
...
@@ -45,6 +46,16 @@ def to_device(obj, device):
raise
ValueError
(
"'%s' has unsupported type '%s'"
%
(
obj
,
type
(
obj
)))
raise
ValueError
(
"'%s' has unsupported type '%s'"
%
(
obj
,
type
(
obj
)))
def
to_list
(
arr
):
if
torch
.
is_tensor
(
arr
):
return
arr
.
cpu
().
numpy
().
tolist
()
if
isinstance
(
arr
,
np
.
ndarray
):
return
arr
.
tolist
()
if
isinstance
(
arr
,
(
list
,
tuple
)):
return
list
(
arr
)
return
arr
class
AverageMeterGroup
:
class
AverageMeterGroup
:
"""
"""
Average meter group for multiple average meters.
Average meter group for multiple average meters.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment