Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a7846135
Unverified
Commit
a7846135
authored
Jul 27, 2021
by
Yuge Zhang
Committed by
GitHub
Jul 27, 2021
Browse files
Enable `fixed_arch` on Retiarii (#3972)
parent
08fe2924
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
14 deletions
+64
-14
docs/en_US/NAS/ApiReference.rst
docs/en_US/NAS/ApiReference.rst
+3
-1
docs/en_US/NAS/OneshotTrainer.rst
docs/en_US/NAS/OneshotTrainer.rst
+7
-1
docs/en_US/NAS/WriteOneshot.rst
docs/en_US/NAS/WriteOneshot.rst
+1
-1
examples/nas/oneshot/darts/model.py
examples/nas/oneshot/darts/model.py
+4
-4
examples/nas/oneshot/darts/retrain.py
examples/nas/oneshot/darts/retrain.py
+3
-3
nni/retiarii/__init__.py
nni/retiarii/__init__.py
+1
-0
nni/retiarii/fixed.py
nni/retiarii/fixed.py
+40
-0
nni/retiarii/oneshot/pytorch/darts.py
nni/retiarii/oneshot/pytorch/darts.py
+5
-4
No files found.
docs/en_US/NAS/ApiReference.rst
View file @
a7846135
...
...
@@ -106,3 +106,5 @@ Utilities
---------
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.fixed_arch
docs/en_US/NAS/OneshotTrainer.rst
View file @
a7846135
...
...
@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
trainer
.
fit
()
final_architecture
=
trainer
.
export
()
**Format of the exported architecture.** TBD.
After
the
searching
is
done
,
we
can
use
the
exported
architecture
to
instantiate
the
full
network
for
retraining
.
Here
is
an
example
:
..
code
-
block
::
python
from
nni
.
retiarii
import
fixed_arch
with
fixed_arch
(
'/path/to/checkpoint.json'
):
model
=
Model
()
docs/en_US/NAS/WriteOneshot.rst
View file @
a7846135
...
...
@@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.
key
self.name = layer_choice.
label
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
...
...
examples/nas/oneshot/darts/model.py
View file @
a7846135
...
...
@@ -7,7 +7,7 @@ import torch
import
torch.nn
as
nn
import
ops
from
nni.
nas
.pytorch
import
mutables
from
nni.
retiarii.nn
.pytorch
import
LayerChoice
,
InputChoice
class
AuxiliaryHead
(
nn
.
Module
):
...
...
@@ -45,7 +45,7 @@ class Node(nn.Module):
stride
=
2
if
i
<
num_downsample_connect
else
1
choice_keys
.
append
(
"{}_p{}"
.
format
(
node_id
,
i
))
self
.
ops
.
append
(
mutables
.
LayerChoice
(
OrderedDict
([
LayerChoice
(
OrderedDict
([
(
"maxpool"
,
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"avgpool"
,
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"skipconnect"
,
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
)),
...
...
@@ -53,9 +53,9 @@ class Node(nn.Module):
(
"sepconv5x5"
,
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
)),
(
"dilconv3x3"
,
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
)),
(
"dilconv5x5"
,
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
))
]),
key
=
choice_keys
[
-
1
]))
]),
label
=
choice_keys
[
-
1
]))
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
self
.
input_switch
=
InputChoice
(
n_candidates
=
len
(
choice_keys
)
,
n_chosen
=
2
,
label
=
"{}_switch"
.
format
(
node_id
))
def
forward
(
self
,
prev_nodes
):
assert
len
(
self
.
ops
)
==
len
(
prev_nodes
)
...
...
examples/nas/oneshot/darts/retrain.py
View file @
a7846135
...
...
@@ -12,8 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
import
datasets
import
utils
from
model
import
CNN
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
from
nni.retiarii
import
fixed_arch
logger
=
logging
.
getLogger
(
'nni'
)
...
...
@@ -119,8 +119,8 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
,
cutout_length
=
16
)
with
fixed_arch
(
args
.
arc_checkpoint
):
model
=
CNN
(
32
,
3
,
36
,
10
,
args
.
layers
,
auxiliary
=
True
)
apply_fixed_architecture
(
model
,
args
.
arc_checkpoint
)
criterion
=
nn
.
CrossEntropyLoss
()
model
.
to
(
device
)
...
...
nni/retiarii/__init__.py
View file @
a7846135
...
...
@@ -4,5 +4,6 @@
from
.operation
import
Operation
from
.graph
import
*
from
.execution
import
*
from
.fixed
import
fixed_arch
from
.mutator
import
*
from
.serializer
import
basic_unit
,
json_dump
,
json_dumps
,
json_load
,
json_loads
,
serialize
,
serialize_cls
,
model_wrapper
nni/retiarii/fixed.py
0 → 100644
View file @
a7846135
import
json
import
logging
from
pathlib
import
Path
from
typing
import
Union
,
Dict
,
Any
from
.utils
import
ContextStack
_logger
=
logging
.
getLogger
(
__name__
)
def
fixed_arch
(
fixed_arch
:
Union
[
str
,
Path
,
Dict
[
str
,
Any
]],
verbose
=
True
):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""
if
isinstance
(
fixed_arch
,
(
str
,
Path
)):
with
open
(
fixed_arch
)
as
f
:
fixed_arch
=
json
.
load
(
f
)
if
verbose
:
_logger
.
info
(
f
'Fixed architecture: %s'
,
fixed_arch
)
return
ContextStack
(
'fixed'
,
fixed_arch
)
nni/retiarii/oneshot/pytorch/darts.py
View file @
a7846135
...
...
@@ -3,6 +3,7 @@
import
copy
import
logging
from
collections
import
OrderedDict
import
torch
import
torch.nn
as
nn
...
...
@@ -18,8 +19,8 @@ _logger = logging.getLogger(__name__)
class
DartsLayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
layer_choice
):
super
(
DartsLayerChoice
,
self
).
__init__
()
self
.
name
=
layer_choice
.
key
self
.
op_choices
=
nn
.
ModuleDict
(
layer_choice
.
named_children
(
))
self
.
name
=
layer_choice
.
label
self
.
op_choices
=
nn
.
ModuleDict
(
OrderedDict
([(
name
,
layer_choice
[
name
])
for
name
in
layer_choice
.
names
]
))
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
len
(
self
.
op_choices
))
*
1e-3
)
def
forward
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -38,13 +39,13 @@ class DartsLayerChoice(nn.Module):
yield
name
,
p
def
export
(
self
):
return
torch
.
argmax
(
self
.
alpha
).
item
()
return
list
(
self
.
op_choices
.
keys
())[
torch
.
argmax
(
self
.
alpha
).
item
()
]
class
DartsInputChoice
(
nn
.
Module
):
def
__init__
(
self
,
input_choice
):
super
(
DartsInputChoice
,
self
).
__init__
()
self
.
name
=
input_choice
.
key
self
.
name
=
input_choice
.
label
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
input_choice
.
n_candidates
)
*
1e-3
)
self
.
n_chosen
=
input_choice
.
n_chosen
or
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment