Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bf7daa8f
".github/vscode:/vscode.git/clone" did not exist on "7de095e576c66b2b0dafa0c4fd271f936e07ec09"
Unverified
Commit
bf7daa8f
authored
May 11, 2020
by
Yuge Zhang
Committed by
GitHub
May 11, 2020
Browse files
Prettify the export format of NAS trainer (#2389)
parent
af800213
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
29 deletions
+135
-29
docs/en_US/NAS/NasGuide.md
docs/en_US/NAS/NasGuide.md
+14
-3
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+33
-18
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+77
-8
src/sdk/pynni/nni/nas/pytorch/utils.py
src/sdk/pynni/nni/nas/pytorch/utils.py
+11
-0
No files found.
docs/en_US/NAS/NasGuide.md
View file @
bf7daa8f
...
@@ -156,12 +156,23 @@ model = Net()
...
@@ -156,12 +156,23 @@ model = Net()
apply_fixed_architecture
(
model
,
"model_dir/final_architecture.json"
)
apply_fixed_architecture
(
model
,
"model_dir/final_architecture.json"
)
```
```
The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in:
*
A string: select the candidate with corresponding name.
*
A number: select the candidate with corresponding index.
*
A list of string: select the candidates with corresponding names.
*
A list of number: select the candidates with corresponding indices.
*
A list of boolean values: a multi-hot array.
For example,
```
json
```
json
{
{
"LayerChoice1"
:
[
false
,
true
,
false
,
false
],
"LayerChoice1"
:
"conv5x5"
,
"InputChoice2"
:
[
true
,
true
,
false
]
"LayerChoice2"
:
6
,
"InputChoice3"
:
[
"layer1"
,
"layer3"
],
"InputChoice4"
:
[
1
,
2
],
"InputChoice5"
:
[
false
,
true
,
false
,
false
,
true
]
}
}
```
```
...
...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
bf7daa8f
...
@@ -3,10 +3,9 @@
...
@@ -3,10 +3,9 @@
import
json
import
json
import
torch
from
.mutables
import
InputChoice
,
LayerChoice
,
MutableScope
from
.mutator
import
Mutator
from
nni.nas.pytorch.mutables
import
MutableScope
from
.utils
import
to_list
from
nni.nas.pytorch.mutator
import
Mutator
class
FixedArchitecture
(
Mutator
):
class
FixedArchitecture
(
Mutator
):
...
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
...
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
----------
----------
model : nn.Module
model : nn.Module
A mutable network.
A mutable network.
fixed_arc :
str or
dict
fixed_arc : dict
P
ath to the architecture checkpoint (a string), or p
reloaded architecture object
(a dict)
.
Preloaded architecture object.
strict : bool
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
"""
...
@@ -33,6 +32,34 @@ class FixedArchitecture(Mutator):
...
@@ -33,6 +32,34 @@ class FixedArchitecture(Mutator):
raise
RuntimeError
(
"Unexpected keys found in fixed architecture: {}."
.
format
(
fixed_arc_keys
-
mutable_keys
))
raise
RuntimeError
(
"Unexpected keys found in fixed architecture: {}."
.
format
(
fixed_arc_keys
-
mutable_keys
))
if
mutable_keys
-
fixed_arc_keys
:
if
mutable_keys
-
fixed_arc_keys
:
raise
RuntimeError
(
"Missing keys in fixed architecture: {}."
.
format
(
mutable_keys
-
fixed_arc_keys
))
raise
RuntimeError
(
"Missing keys in fixed architecture: {}."
.
format
(
mutable_keys
-
fixed_arc_keys
))
self
.
_fixed_arc
=
self
.
_from_human_readable_architecture
(
self
.
_fixed_arc
)
def
_from_human_readable_architecture
(
self
,
human_arc
):
# convert from an exported architecture
result_arc
=
{
k
:
to_list
(
v
)
for
k
,
v
in
human_arc
.
items
()}
# there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} ir {"op1": ["conv", ]}
result_arc
=
{
k
:
v
if
isinstance
(
v
,
list
)
else
[
v
]
for
k
,
v
in
result_arc
.
items
()}
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
for
mutable
in
self
.
mutables
:
if
mutable
.
key
not
in
result_arc
:
continue
# skip silently
choice_arr
=
result_arc
[
mutable
.
key
]
if
all
(
isinstance
(
v
,
bool
)
for
v
in
choice_arr
)
or
all
(
isinstance
(
v
,
float
)
for
v
in
choice_arr
):
if
(
isinstance
(
mutable
,
LayerChoice
)
and
len
(
mutable
)
==
len
(
choice_arr
))
or
\
(
isinstance
(
mutable
,
InputChoice
)
and
mutable
.
n_candidates
==
len
(
choice_arr
)):
# multihot, do nothing
continue
if
isinstance
(
mutable
,
LayerChoice
):
choice_arr
=
[
mutable
.
names
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
len
(
mutable
))]
elif
isinstance
(
mutable
,
InputChoice
):
choice_arr
=
[
mutable
.
choose_from
.
index
(
val
)
if
isinstance
(
val
,
str
)
else
val
for
val
in
choice_arr
]
choice_arr
=
[
i
in
choice_arr
for
i
in
range
(
mutable
.
n_candidates
)]
result_arc
[
mutable
.
key
]
=
choice_arr
return
result_arc
def
sample_search
(
self
):
def
sample_search
(
self
):
"""
"""
...
@@ -47,17 +74,6 @@ class FixedArchitecture(Mutator):
...
@@ -47,17 +74,6 @@ class FixedArchitecture(Mutator):
return
self
.
_fixed_arc
return
self
.
_fixed_arc
def
_encode_tensor
(
data
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
return
{
k
:
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
def
apply_fixed_architecture
(
model
,
fixed_arc
):
def
apply_fixed_architecture
(
model
,
fixed_arc
):
"""
"""
Load architecture from `fixed_arc` and apply to model.
Load architecture from `fixed_arc` and apply to model.
...
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
...
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
_encode_tensor
(
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
.
reset
()
architecture
.
reset
()
return
architecture
return
architecture
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
bf7daa8f
...
@@ -7,7 +7,9 @@ from collections import defaultdict
...
@@ -7,7 +7,9 @@ from collections import defaultdict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
.base_mutator
import
BaseMutator
from
.mutables
import
LayerChoice
,
InputChoice
from
.utils
import
to_list
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -58,7 +60,16 @@ class Mutator(BaseMutator):
...
@@ -58,7 +60,16 @@ class Mutator(BaseMutator):
dict
dict
A mapping from key of mutables to decisions.
A mapping from key of mutables to decisions.
"""
"""
return
self
.
sample_final
()
sampled
=
self
.
sample_final
()
result
=
dict
()
for
mutable
in
self
.
mutables
:
if
not
isinstance
(
mutable
,
(
LayerChoice
,
InputChoice
)):
# not supported as built-in
continue
result
[
mutable
.
key
]
=
self
.
_convert_mutable_decision_to_human_readable
(
mutable
,
sampled
.
pop
(
mutable
.
key
))
if
sampled
:
raise
ValueError
(
"Unexpected keys returned from 'sample_final()': %s"
,
list
(
sampled
.
keys
()))
return
result
def
status
(
self
):
def
status
(
self
):
"""
"""
...
@@ -159,7 +170,7 @@ class Mutator(BaseMutator):
...
@@ -159,7 +170,7 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
),
\
assert
len
(
mask
)
==
len
(
mutable
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
))
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwargs
)
for
choice
in
mutable
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
...
@@ -185,17 +196,41 @@ class Mutator(BaseMutator):
...
@@ -185,17 +196,41 @@ class Mutator(BaseMutator):
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
out
,
mask
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
if
"BoolTensor"
in
mask
.
type
():
"""
Select masked tensors and return a list of tensors.
Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, an numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.
Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
bool
))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
==
np
.
bool
)
or
\
"BoolTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
elif
"FloatTensor"
in
mask
.
type
():
elif
(
isinstance
(
mask
,
list
)
and
len
(
mask
)
>=
1
and
isinstance
(
mask
[
0
],
(
float
,
int
)))
or
\
(
isinstance
(
mask
,
np
.
ndarray
)
and
mask
.
dtype
in
(
np
.
float32
,
np
.
float64
,
np
.
int32
,
np
.
int64
))
or
\
"FloatTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
else
:
else
:
raise
ValueError
(
"Unrecognized mask"
)
raise
ValueError
(
"Unrecognized mask '%s'"
%
mask
)
return
out
if
not
torch
.
is_tensor
(
mask
):
mask
=
torch
.
tensor
(
mask
)
# pylint: disable=not-callable
return
out
,
mask
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
if
reduction_type
==
"none"
:
...
@@ -237,3 +272,37 @@ class Mutator(BaseMutator):
...
@@ -237,3 +272,37 @@ class Mutator(BaseMutator):
result
=
self
.
_cache
[
mutable
.
key
]
result
=
self
.
_cache
[
mutable
.
key
]
logger
.
debug
(
"Decision %s: %s"
,
mutable
.
key
,
result
)
logger
.
debug
(
"Decision %s: %s"
,
mutable
.
key
,
result
)
return
result
return
result
def
_convert_mutable_decision_to_human_readable
(
self
,
mutable
,
sampled
):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list
=
to_list
(
sampled
)
converted
=
None
# If it's a boolean array, we can do optimization.
if
all
([
t
==
0
or
t
==
1
for
t
in
multihot_list
]):
if
isinstance
(
mutable
,
LayerChoice
):
assert
len
(
multihot_list
)
==
len
(
mutable
),
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all modules have different names and they indeed have names
if
len
(
set
(
mutable
.
names
))
==
len
(
mutable
)
and
not
all
(
d
.
isdigit
()
for
d
in
mutable
.
names
):
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
names
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
isinstance
(
mutable
,
InputChoice
):
assert
len
(
multihot_list
)
==
mutable
.
n_candidates
,
\
"Results returned from 'sample_final()' (%s: %s) either too short or too long."
\
%
(
mutable
.
key
,
multihot_list
)
# check if all input candidates have different names
if
len
(
set
(
mutable
.
choose_from
))
==
mutable
.
n_candidates
:
converted
=
[
name
for
i
,
name
in
enumerate
(
mutable
.
choose_from
)
if
multihot_list
[
i
]]
else
:
converted
=
[
i
for
i
in
range
(
len
(
multihot_list
))
if
multihot_list
[
i
]]
if
converted
is
not
None
:
# if only one element, then remove the bracket
if
len
(
converted
)
==
1
:
converted
=
converted
[
0
]
else
:
# do nothing
converted
=
multihot_list
return
converted
src/sdk/pynni/nni/nas/pytorch/utils.py
View file @
bf7daa8f
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
logging
import
logging
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
import
torch
_counter
=
0
_counter
=
0
...
@@ -45,6 +46,16 @@ def to_device(obj, device):
...
@@ -45,6 +46,16 @@ def to_device(obj, device):
raise
ValueError
(
"'%s' has unsupported type '%s'"
%
(
obj
,
type
(
obj
)))
raise
ValueError
(
"'%s' has unsupported type '%s'"
%
(
obj
,
type
(
obj
)))
def
to_list
(
arr
):
if
torch
.
is_tensor
(
arr
):
return
arr
.
cpu
().
numpy
().
tolist
()
if
isinstance
(
arr
,
np
.
ndarray
):
return
arr
.
tolist
()
if
isinstance
(
arr
,
(
list
,
tuple
)):
return
list
(
arr
)
return
arr
class
AverageMeterGroup
:
class
AverageMeterGroup
:
"""
"""
Average meter group for multiple average meters.
Average meter group for multiple average meters.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment