OpenDAS / nni · Commits

Commit 867871b2 (unverified)
Authored Jul 27, 2022 by Yuge Zhang; committed by GitHub on Jul 27, 2022
Promote Retiarii to NAS (step 1) - move files (#5020)
Parent: 481aa292
Changes: 137
Showing 17 changed files with 0 additions and 745 deletions.
nni/nas/pytorch/utils.py              +0 -210
nni/nas/strategy/base.py              +0 -0
nni/nas/strategy/bruteforce.py        +0 -0
nni/nas/strategy/debug.py             +0 -0
nni/nas/strategy/evolution.py         +0 -0
nni/nas/strategy/hpo.py               +0 -0
nni/nas/strategy/oneshot.py           +0 -0
nni/nas/strategy/rl.py                +0 -0
nni/nas/strategy/utils.py             +0 -0
nni/nas/tensorflow/__init__.py        +0 -0
nni/nas/tensorflow/base_mutator.py    +0 -73
nni/nas/tensorflow/mutables.py        +0 -144
nni/nas/tensorflow/mutator.py         +0 -83
nni/nas/tensorflow/utils.py           +0 -93
nni/nas/utils/misc.py                 +0 -0
nni/nas/utils/serializer.py           +0 -0
nni/retiarii/oneshot/pytorch/enas.py  +0 -142
nni/nas/pytorch/utils.py  deleted, 100644 → 0  (view file @ 481aa292)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict

import numpy as np
import torch

_counter = 0

_logger = logging.getLogger(__name__)


def global_mutable_counting():
    """
    A program level counter starting from 1.
    """
    global _counter
    _counter += 1
    return _counter


def _reset_global_mutable_counting():
    """
    Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
    """
    global _counter
    _counter = 0


def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(to_device(t, device) for t in obj)
    if isinstance(obj, list):
        return [to_device(t, device) for t in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
    if torch.is_tensor(arr):
        return arr.cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr


class AverageMeterGroup:
    """
    Average meter group for multiple average meters.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-exist average meters will be automatically created.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        return self.meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())


class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)


class StructuredMutableTreeNode:
    """
    A structured representation of a search space.
    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Add a tree node to the children list of current node.
        """
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that generates a list of mutables in this tree.

        Parameters
        ----------
        order : str
            pre or post. If pre, current mutable is yield before children. Otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after the first appearance.
        memo : dict
            An auxiliary dict that memorize keys seen before, so that deduplication is possible.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
        for child in self.children:
            for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
                yield m
        if order == "post":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
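For orientation, below is a minimal usage sketch of the helpers deleted above. The metric names, shapes, and values are invented; AverageMeterGroup and to_device refer to the definitions in the deleted file, not to any surviving NNI API.

# Minimal, hypothetical usage sketch of the deleted helpers (values are made up).
import torch

meters = AverageMeterGroup()            # as defined in the deleted file above
for step in range(3):
    # update() takes a dict of metrics; unseen keys create new AverageMeter entries
    meters.update({"loss": 1.0 / (step + 1), "acc": 0.5 + 0.1 * step})
print(meters.summary())                 # e.g. "loss: 0.611111 acc: 0.600000"

batch = {"x": torch.zeros(2, 3), "y": [torch.ones(2), 1]}
batch = to_device(batch, "cpu")         # recursively moves tensors inside dict/list/tuple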
nni/retiarii/strategy/base.py → nni/nas/strategy/base.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/bruteforce.py → nni/nas/strategy/bruteforce.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/local_debug_strategy.py → nni/nas/strategy/debug.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/evolution.py → nni/nas/strategy/evolution.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/tpe_strategy.py → nni/nas/strategy/hpo.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/oneshot.py → nni/nas/strategy/oneshot.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/rl.py → nni/nas/strategy/rl.py  (file moved; view file @ 867871b2)
nni/retiarii/strategy/utils.py → nni/nas/strategy/utils.py  (file moved; view file @ 867871b2)
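The eight moves above only relocate files on disk. As a purely illustrative sketch, and assuming the relocated modules are importable under the new nni.nas.strategy package after this step (package __init__ files and any compatibility aliases are not shown on this page), the dotted module paths would change as follows:

# Hypothetical sketch: dotted module paths implied by the file moves above.
# Assumption: nni.nas.strategy is an importable package after this commit.
import importlib

NEW = "nni.nas.strategy.evolution"       # location after the move
OLD = "nni.retiarii.strategy.evolution"  # location before this commit

try:
    evolution = importlib.import_module(NEW)
except ImportError:
    evolution = importlib.import_module(OLD)  # older layouts keep the retiarii path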
nni/nas/tensorflow/__init__.py  deleted, 100644 → 0  (view file @ 481aa292)
nni/nas/tensorflow/base_mutator.py  deleted, 100644 → 0  (view file @ 481aa292)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tensorflow.keras import Model

from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode


class BaseMutator(Model):
    def __init__(self, model):
        super().__init__()
        self.__dict__['model'] = model
        self._structured_mutables = self._parse_search_space(self.model)

    def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
        if memo is None:
            memo = set()
        if root is None:
            root = StructuredMutableTreeNode(None)
        if module not in memo:
            memo.add(module)
            if isinstance(module, Mutable):
                if nested_detection is not None:
                    raise RuntimeError('Cannot have nested search space. Error at {} in {}'
                                       .format(module, nested_detection))
                module.name = prefix
                module.set_mutator(self)
                root = root.add_child(module)
                if not isinstance(module, MutableScope):
                    nested_detection = module
                if isinstance(module, InputChoice):
                    for k in module.choose_from:
                        if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
                            raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.'
                                               .format(k, module.key))
            for submodule in module.layers:
                if not isinstance(submodule, Model):
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
                self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
                                         nested_detection=nested_detection)
        return root

    @property
    def mutables(self):
        return self._structured_mutables

    def undedup_mutables(self):
        return self._structured_mutables.traverse(deduplicate=False)

    def call(self, *inputs):
        raise RuntimeError('Call is undefined for mutators.')

    def __setattr__(self, name, value):
        if name == 'model':
            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
                                 "include your network, as it will include all parameters in model into the mutator.")
        return super().__setattr__(name, value)

    def enter_mutable_scope(self, mutable_scope):
        pass

    def exit_mutable_scope(self, mutable_scope):
        pass

    def on_forward_layer_choice(self, mutable, *inputs):
        raise NotImplementedError

    def on_forward_input_choice(self, mutable, tensor_list):
        raise NotImplementedError

    def export(self):
        raise NotImplementedError
nni/nas/tensorflow/mutables.py  deleted, 100644 → 0  (view file @ 481aa292)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict

from tensorflow.keras import Model

from .utils import global_mutable_counting

_logger = logging.getLogger(__name__)


class Mutable(Model):
    def __init__(self, key=None):
        super().__init__()
        if key is None:
            self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
        elif isinstance(key, str):
            self._key = key
        else:
            self._key = str(key)
            _logger.warning('Key "%s" is not string, converted to string.', key)
        self.init_hook = None
        self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def set_mutator(self, mutator):
        if hasattr(self, 'mutator'):
            raise RuntimeError('`set_mutator is called more than once. '
                               'Did you parse the search space multiple times? '
                               'Or did you apply multiple fixed architectures?')
        self.mutator = mutator

    def call(self, *inputs):
        raise NotImplementedError('Method `call` of Mutable must be overridden')

    def build(self, input_shape):
        self._check_built()

    @property
    def key(self):
        return self._key

    @property
    def name(self):
        return self._name if hasattr(self, '_name') else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        if not hasattr(self, 'mutator'):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return '{} ({})'.format(self.name, self.key)


class MutableScope(Mutable):
    def __call__(self, *args, **kwargs):
        try:
            self.mutator.enter_mutable_scope(self)
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)


class LayerChoice(Mutable):
    def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        self.names = []
        if isinstance(op_candidates, OrderedDict):
            for name in op_candidates:
                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                    "Please don't use a reserved name '{}' for your module.".format(name)
                self.names.append(name)
        elif isinstance(op_candidates, list):
            for i, _ in enumerate(op_candidates):
                self.names.append(str(i))
        else:
            raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
        self.length = len(op_candidates)
        self.choices = op_candidates
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, *inputs):
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def build(self, input_shape):
        self._check_built()
        for op in self.choices:
            op.build(input_shape)

    def __len__(self):
        return len(self.choices)


class InputChoice(Mutable):
    NO_KEY = ''

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
                 reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates is not None or choose_from is not None, \
            'At least one of `n_candidates` and `choose_from` must be not None.'
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
        assert n_candidates > 0, 'Number of candidates must be greater than 0.'
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
            'Expected selected number must be None or no more than number of candidates.'

        self.n_candidates = n_candidates
        self.choose_from = choose_from.copy()
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, optional_inputs):
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
        assert len(optional_inputs) == self.n_candidates, \
            'Length of the input list must be equal to number of candidates.'
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out
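As a rough, hypothetical illustration of the constructors above: a LayerChoice can be built from an OrderedDict of candidate Keras layers, and an InputChoice from a candidate count. The names and layer choices here are invented, and a mutator still has to be attached (via set_mutator) before either can be built or called.

# Hypothetical sketch: constructing the mutables defined in the deleted file above.
from collections import OrderedDict
import tensorflow as tf

conv_choice = LayerChoice(OrderedDict([
    ("conv3x3", tf.keras.layers.Conv2D(16, 3, padding="same")),
    ("conv5x5", tf.keras.layers.Conv2D(16, 5, padding="same")),
]), key="first_conv")

skip_choice = InputChoice(n_candidates=2, n_chosen=1, key="skip")
# Both mutables need a mutator attached (set_mutator) before build() or call() succeeds.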
nni/nas/tensorflow/mutator.py  deleted, 100644 → 0  (view file @ 481aa292)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import tensorflow as tf

from .base_mutator import BaseMutator

_logger = logging.getLogger(__name__)


class Mutator(BaseMutator):
    def __init__(self, model):
        super().__init__(model)
        self._cache = {}

    def sample_search(self):
        raise NotImplementedError('Method `sample_search` must be overridden')

    def sample_final(self):
        raise NotImplementedError('Method `sample_final` must be overriden for exporting')

    def reset(self):
        self._cache = self.sample_search()

    def export(self):
        return self.sample_final()

    # TODO: status
    # TODO: graph

    def on_forward_layer_choice(self, mutable, *inputs):
        mask = self._get_decision(mutable)
        assert len(mask) == len(mutable), \
            'Invalid mask, expected {} to be of length {}.'.format(mask, len(mutable))
        out = self._select_with_mask(lambda choice: choice(*inputs), mutable.choices, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list):
        mask = self._get_decision(mutable)
        assert len(mask) == mutable.n_candidates, \
            'Invalid mask, expected {} to be of length {}.'.format(mask, mutable.n_candidates)
        out = self._select_with_mask(lambda tensor: tensor, tensor_list, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def _select_with_mask(self, map_fn, candidates, mask):
        if mask.dtype.is_bool:
            out = [map_fn(cand) for cand, m in zip(candidates, mask) if m]
        elif mask.dtype.is_floating:
            out = [map_fn(cand) * m for cand, m in zip(candidates, mask) if m]
        else:
            raise ValueError('Unrecognized mask, dtype is {}'.format(mask.dtype.name))
        return out

    def _tensor_reduction(self, reduction_type, tensor_list):
        if reduction_type == 'none':
            return tensor_list
        if not tensor_list:
            return None
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == 'sum':
            return sum(tensor_list)
        if reduction_type == 'mean':
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == 'concat':
            image_data_format = tf.keras.backend.image_data_format()
            if image_data_format == "channels_first":
                axis = 0
            else:
                axis = -1
            return tf.concat(tensor_list, axis=axis)  # pylint: disable=E1120,E1123  # pylint issue #3613
        raise ValueError('Unrecognized reduction policy: "{}'.format(reduction_type))

    def _get_decision(self, mutable):
        if mutable.key not in self._cache:
            raise ValueError('"{}" not found in decision cache.'.format(mutable.key))
        result = self._cache[mutable.key]
        _logger.debug('Decision %s: %s', mutable.key, result)
        return result
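For context, here is a minimal, hypothetical Mutator subclass (not part of NNI) that fills the decision cache with uniformly sampled boolean masks, using the Mutator, LayerChoice and InputChoice classes from the deleted files above:

# Hypothetical sketch of a concrete Mutator built on the deleted classes above.
import random
import tensorflow as tf

class RandomMutator(Mutator):           # Mutator as defined in the deleted file above
    def sample_search(self):
        result = {}
        for mutable in self.mutables:
            # LayerChoice carries a `choices` list; InputChoice exposes n_candidates instead.
            n = len(mutable) if hasattr(mutable, 'choices') else mutable.n_candidates
            chosen = random.randrange(n)
            result[mutable.key] = tf.constant([i == chosen for i in range(n)], dtype=tf.bool)
        return result

    def sample_final(self):
        return self.sample_search()

# Typical driving loop: mutator = RandomMutator(model); mutator.reset() before each forward pass.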
nni/nas/tensorflow/utils.py  deleted, 100644 → 0  (view file @ 481aa292)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf

_counter = 0


def global_mutable_counting():
    global _counter
    _counter += 1
    return _counter


class AverageMeter:
    def __init__(self, name):
        self.name = name
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def __str__(self):
        return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__)

    def summary(self):
        return '{name}: {avg:4f}'.format(**self.__dict__)


class AverageMeterGroup:
    def __init__(self):
        self.meters = {}

    def update(self, data):
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k)
            self.meters[k].update(v)

    def __str__(self):
        return ' '.join(str(v) for v in self.meters.values())

    def summary(self):
        return ' '.join(v.summary() for v in self.meters.values())


class StructuredMutableTreeNode:
    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
        for child in self.children:
            for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
                yield m
        if order == "post":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable


def fill_zero_grads(grads, weights):
    ret = []
    for grad, weight in zip(grads, weights):
        if grad is not None:
            ret.append(grad)
        else:
            ret.append(tf.zeros_like(weight))
    return ret
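A brief, hypothetical illustration of fill_zero_grads above (the model, data and loss are invented): gradients returned as None by tf.GradientTape are replaced with zero tensors before being applied.

# Hypothetical usage sketch of fill_zero_grads with a made-up model and loss.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.1)
x = tf.random.normal((8, 16))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(model(x) ** 2)
grads = tape.gradient(loss, model.trainable_weights)     # entries may be None
grads = fill_zero_grads(grads, model.trainable_weights)  # None -> tf.zeros_like(weight)
optimizer.apply_gradients(zip(grads, model.trainable_weights))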
nni/retiarii/utils.py → nni/nas/utils/misc.py  (file moved; view file @ 867871b2)
nni/retiarii/serializer.py → nni/nas/utils/serializer.py  (file moved; view file @ 867871b2)
nni/retiarii/oneshot/pytorch/enas.py  (view file @ 867871b2)

@@ -18,148 +18,6 @@ from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice

The classes below are removed by this commit; surrounding context lines are shown once.

_logger = logging.getLogger(__name__)


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_h, next_c


class ReinforceField:
    """
    A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
    selected. Otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name = name
        self.total = total
        self.choose_one = choose_one

    def __repr__(self):
        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'


class ReinforceController(nn.Module):
    """
    A controller that mutates the graph with RL.

    Parameters
    ----------
    fields : list of ReinforceField
        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    skip_target : float
        Target probability that skipconnect (chosen by InputChoice) will appear.
        If the chosen number of inputs is away from the ``skip_connect``, there will be
        a sample skip penalty which is a KL divergence added.
    temperature : float
        Temperature constant that divides the logits.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                 skip_target=0.4, temperature=None, entropy_reduction='sum'):
        super(ReinforceController, self).__init__()
        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.skip_target = skip_target

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
                                         requires_grad=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        self.soft = nn.ModuleDict({
            field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
        })
        self.embedding = nn.ModuleDict({
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

    def resample(self):
        self._initialize()
        result = dict()
        for field in self.fields:
            result[field.name] = self._sample_single(field)
        return result

    def _initialize(self):
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _sample_single(self, field):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, sampled)
            sampled = sampled.nonzero().view(-1)
            if sampled.sum().item():
                self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0)
                                / (1. + torch.sum(sampled))).unsqueeze(0)
            else:
                self._inputs = torch.zeros(1, self.lstm_size,
                                           device=self.embedding[field.name].weight.device)  # type: ignore

        sampled = sampled.detach().cpu().numpy().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
        return sampled


class EnasTrainer(BaseOneShotTrainer):
    """
    ENAS trainer.
    ...
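To make the removed controller concrete, here is a small, hypothetical example of driving it directly; the field names and sizes are invented, and this is not the actual EnasTrainer loop:

# Hypothetical sketch: sampling decisions with the removed ReinforceController.
fields = [
    ReinforceField("conv_op", total=3, choose_one=True),        # pick exactly one of 3 ops
    ReinforceField("skip_connect", total=2, choose_one=False),  # any subset of 2 inputs
]
controller = ReinforceController(fields, lstm_size=64, lstm_num_layers=1)

decision = controller.resample()        # e.g. {'conv_op': 1, 'skip_connect': [0, 1]}
log_prob = controller.sample_log_prob   # REINFORCE objective term for this sample
entropy = controller.sample_entropy     # optional exploration bonus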