Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a7846135
Unverified
Commit
a7846135
authored
Jul 27, 2021
by
Yuge Zhang
Committed by
GitHub
Jul 27, 2021
Browse files
Enable `fixed_arch` on Retiarii (#3972)
parent
08fe2924
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
14 deletions
+64
-14
docs/en_US/NAS/ApiReference.rst
docs/en_US/NAS/ApiReference.rst
+3
-1
docs/en_US/NAS/OneshotTrainer.rst
docs/en_US/NAS/OneshotTrainer.rst
+7
-1
docs/en_US/NAS/WriteOneshot.rst
docs/en_US/NAS/WriteOneshot.rst
+1
-1
examples/nas/oneshot/darts/model.py
examples/nas/oneshot/darts/model.py
+4
-4
examples/nas/oneshot/darts/retrain.py
examples/nas/oneshot/darts/retrain.py
+3
-3
nni/retiarii/__init__.py
nni/retiarii/__init__.py
+1
-0
nni/retiarii/fixed.py
nni/retiarii/fixed.py
+40
-0
nni/retiarii/oneshot/pytorch/darts.py
nni/retiarii/oneshot/pytorch/darts.py
+5
-4
No files found.
docs/en_US/NAS/ApiReference.rst
View file @
a7846135
...
...
@@ -105,4 +105,6 @@ Retiarii Experiments
Utilities
---------
.. autofunction:: nni.retiarii.serialize
\ No newline at end of file
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.fixed_arch
docs/en_US/NAS/OneshotTrainer.rst
View file @
a7846135
...
...
@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
trainer
.
fit
()
final_architecture
=
trainer
.
export
()
**Format of the exported architecture.** TBD.
After
the
searching
is
done
,
we
can
use
the
exported
architecture
to
instantiate
the
full
network
for
retraining
.
Here
is
an
example
:
..
code
-
block
::
python
from
nni
.
retiarii
import
fixed_arch
with
fixed_arch
(
'/path/to/checkpoint.json'
):
model
=
Model
()
docs/en_US/NAS/WriteOneshot.rst
View file @
a7846135
...
...
@@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.
key
self.name = layer_choice.
label
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
...
...
examples/nas/oneshot/darts/model.py
View file @
a7846135
...
...
@@ -7,7 +7,7 @@ import torch
import
torch.nn
as
nn
import
ops
from
nni.
nas
.pytorch
import
mutables
from
nni.
retiarii.nn
.pytorch
import
LayerChoice
,
InputChoice
class
AuxiliaryHead
(
nn
.
Module
):
...
...
@@ -45,7 +45,7 @@ class Node(nn.Module):
stride
=
2
if
i
<
num_downsample_connect
else
1
choice_keys
.
append
(
"{}_p{}"
.
format
(
node_id
,
i
))
self
.
ops
.
append
(
mutables
.
LayerChoice
(
OrderedDict
([
LayerChoice
(
OrderedDict
([
(
"maxpool"
,
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"avgpool"
,
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"skipconnect"
,
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
)),
...
...
@@ -53,9 +53,9 @@ class Node(nn.Module):
(
"sepconv5x5"
,
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
)),
(
"dilconv3x3"
,
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
)),
(
"dilconv5x5"
,
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
))
]),
key
=
choice_keys
[
-
1
]))
]),
label
=
choice_keys
[
-
1
]))
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
self
.
input_switch
=
InputChoice
(
n_candidates
=
len
(
choice_keys
)
,
n_chosen
=
2
,
label
=
"{}_switch"
.
format
(
node_id
))
def
forward
(
self
,
prev_nodes
):
assert
len
(
self
.
ops
)
==
len
(
prev_nodes
)
...
...
examples/nas/oneshot/darts/retrain.py
View file @
a7846135
...
...
@@ -12,8 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
import
datasets
import
utils
from
model
import
CNN
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
from
nni.retiarii
import
fixed_arch
logger
=
logging
.
getLogger
(
'nni'
)
...
...
@@ -119,8 +119,8 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
,
cutout_length
=
16
)
model
=
CNN
(
32
,
3
,
36
,
10
,
args
.
layers
,
auxiliary
=
True
)
apply_fixed_architecture
(
model
,
args
.
arc_checkpoint
)
with
fixed_arch
(
args
.
arc_checkpoint
):
model
=
CNN
(
32
,
3
,
36
,
10
,
args
.
layers
,
auxiliary
=
True
)
criterion
=
nn
.
CrossEntropyLoss
()
model
.
to
(
device
)
...
...
nni/retiarii/__init__.py
View file @
a7846135
...
...
@@ -4,5 +4,6 @@
from
.operation
import
Operation
from
.graph
import
*
from
.execution
import
*
from
.fixed
import
fixed_arch
from
.mutator
import
*
from
.serializer
import
basic_unit
,
json_dump
,
json_dumps
,
json_load
,
json_loads
,
serialize
,
serialize_cls
,
model_wrapper
nni/retiarii/fixed.py
0 → 100644
View file @
a7846135
import
json
import
logging
from
pathlib
import
Path
from
typing
import
Union
,
Dict
,
Any
from
.utils
import
ContextStack
_logger
=
logging
.
getLogger
(
__name__
)
def
fixed_arch
(
fixed_arch
:
Union
[
str
,
Path
,
Dict
[
str
,
Any
]],
verbose
=
True
):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""
if
isinstance
(
fixed_arch
,
(
str
,
Path
)):
with
open
(
fixed_arch
)
as
f
:
fixed_arch
=
json
.
load
(
f
)
if
verbose
:
_logger
.
info
(
f
'Fixed architecture: %s'
,
fixed_arch
)
return
ContextStack
(
'fixed'
,
fixed_arch
)
nni/retiarii/oneshot/pytorch/darts.py
View file @
a7846135
...
...
@@ -3,6 +3,7 @@
import
copy
import
logging
from
collections
import
OrderedDict
import
torch
import
torch.nn
as
nn
...
...
@@ -18,8 +19,8 @@ _logger = logging.getLogger(__name__)
class
DartsLayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
layer_choice
):
super
(
DartsLayerChoice
,
self
).
__init__
()
self
.
name
=
layer_choice
.
key
self
.
op_choices
=
nn
.
ModuleDict
(
layer_choice
.
named_children
(
))
self
.
name
=
layer_choice
.
label
self
.
op_choices
=
nn
.
ModuleDict
(
OrderedDict
([(
name
,
layer_choice
[
name
])
for
name
in
layer_choice
.
names
]
))
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
len
(
self
.
op_choices
))
*
1e-3
)
def
forward
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -38,13 +39,13 @@ class DartsLayerChoice(nn.Module):
yield
name
,
p
def
export
(
self
):
return
torch
.
argmax
(
self
.
alpha
).
item
()
return
list
(
self
.
op_choices
.
keys
())[
torch
.
argmax
(
self
.
alpha
).
item
()
]
class
DartsInputChoice
(
nn
.
Module
):
def
__init__
(
self
,
input_choice
):
super
(
DartsInputChoice
,
self
).
__init__
()
self
.
name
=
input_choice
.
key
self
.
name
=
input_choice
.
label
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
input_choice
.
n_candidates
)
*
1e-3
)
self
.
n_chosen
=
input_choice
.
n_chosen
or
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment