OpenDAS / nni · commit 77e91e8b

Commit 77e91e8b, authored Nov 21, 2019 by Yuge Zhang, committed by Chi Song on Nov 21, 2019
Parent: 9dda5370

    Extract controller from mutator to make offline decisions (#1758)

Showing 20 changed files with 601 additions and 330 deletions (+601, -330)
Changed files:

examples/nas/.gitignore                              +2    -1
examples/nas/darts/model.py                          +21   -14
examples/nas/darts/retrain.py                        +143  -0
examples/nas/darts/search.py                         +2    -2
examples/nas/enas/macro.py                           +10   -10
examples/nas/enas/micro.py                           +17   -16
examples/nas/enas/search.py                          +4    -3
src/sdk/pynni/nni/nas/pytorch/base_mutator.py        +98   -41
src/sdk/pynni/nni/nas/pytorch/base_trainer.py        +5    -1
src/sdk/pynni/nni/nas/pytorch/callbacks.py           +1    -19
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py      +1    -2
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py       +35   -23
src/sdk/pynni/nni/nas/pytorch/darts/scope.py         +0    -11
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py       +16   -13
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py        +35   -21
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py        +12   -10
src/sdk/pynni/nni/nas/pytorch/fixed.py               +55   -35
src/sdk/pynni/nni/nas/pytorch/mutables.py            +76   -32
src/sdk/pynni/nni/nas/pytorch/mutator.py             +65   -73
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py      +3    -3
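Annotation (not part of the commit page): the diffs below replace the old forward_pass()/before_pass() hooks with an explicit reset()/export() decision cycle on the mutator, and move checkpoint export from the mutator to the trainer. A minimal sketch of that cycle on a toy network, assuming the post-commit API as used by the trainers in these diffs (Mutator.reset() caches sample_search(), Mutator.export() returns the sample_final() decisions); ToyNet and its keys are made up for illustration.

# Sketch: sample soft decisions for search, then export hard, JSON-serializable decisions.
import json
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from nni.nas.pytorch.darts import DartsMutator

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # one layer choice between two candidate convolutions
        self.conv = mutables.LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                          nn.Conv2d(3, 8, 5, padding=2)], key="conv")
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.conv(x))
        return self.fc(x.view(x.size(0), -1))

model = ToyNet()
mutator = DartsMutator(model)
mutator.reset()                                     # sample_search(): soft architecture weights
logits = model(torch.randn(2, 3, 32, 32))
decisions = mutator.export()                        # sample_final(): hard decisions for offline retraining
print(json.dumps({k: v.tolist() for k, v in decisions.items()}, indent=2))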
examples/nas/.gitignore
 data
+checkpoints
examples/nas/darts/model.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn

 import ops
-from nni.nas.pytorch import mutables, darts
+from nni.nas.pytorch import mutables


 class AuxiliaryHead(nn.Module):
@@ -31,12 +31,14 @@ class AuxiliaryHead(nn.Module):
         return logits


-class Node(darts.DartsNode):
-    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.):
-        super().__init__(node_id, limitation=2)
+class Node(nn.Module):
+    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
+        super().__init__()
         self.ops = nn.ModuleList()
+        choice_keys = []
         for i in range(num_prev_nodes):
             stride = 2 if i < num_downsample_connect else 1
+            choice_keys.append("{}_p{}".format(node_id, i))
             self.ops.append(mutables.LayerChoice(
                 [
@@ -48,18 +50,19 @@ class Node(darts.DartsNode):
                     ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                     ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
                 ],
-                key="{}_p{}".format(node_id, i)))
-        self.drop_path = ops.DropPath_(drop_path_prob)
+                key=choice_keys[-1]))
+        self.drop_path = ops.DropPath_()
+        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

     def forward(self, prev_nodes):
         assert len(self.ops) == len(prev_nodes)
         out = [op(node) for op, node in zip(self.ops, prev_nodes)]
-        return sum(self.drop_path(o) for o in out if o is not None)
+        return self.input_switch(out)


 class Cell(nn.Module):
-    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.):
+    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
         super().__init__()
         self.reduction = reduction
         self.n_nodes = n_nodes
@@ -74,10 +77,9 @@ class Cell(nn.Module):
         # generate dag
         self.mutable_ops = nn.ModuleList()
-        for depth in range(self.n_nodes):
-            self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth),
-                                         depth + 2, channels, 2 if reduction else 0,
-                                         drop_path_prob=drop_path_prob))
+        for depth in range(2, self.n_nodes + 2):
+            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
+                                         depth, channels, 2 if reduction else 0))

     def forward(self, s0, s1):
         # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
@@ -93,7 +95,7 @@ class Cell(nn.Module):
 class CNN(nn.Module):

     def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
-                 stem_multiplier=3, auxiliary=False, drop_path_prob=0.):
+                 stem_multiplier=3, auxiliary=False):
         super().__init__()
         self.in_channels = in_channels
         self.channels = channels
@@ -120,7 +122,7 @@ class CNN(nn.Module):
                 c_cur *= 2
                 reduction = True

-            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob)
+            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
             self.cells.append(cell)
             c_cur_out = c_cur * n_nodes
             channels_pp, channels_p = channels_p, c_cur_out
@@ -147,3 +149,8 @@ class CNN(nn.Module):
         if aux_logits is not None:
             return logits, aux_logits
         return logits
+
+    def drop_path_prob(self, p):
+        for module in self.modules():
+            if isinstance(module, ops.DropPath_):
+                module.p = p
examples/nas/darts/retrain.py  (new file, 0 → 100644)

import logging
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(config, train_loader, model, optimizer, criterion, epoch):
    top1 = AverageMeter("top1")
    top5 = AverageMeter("top5")
    losses = AverageMeter("losses")

    cur_step = epoch * len(train_loader)
    cur_lr = optimizer.param_groups[0]['lr']
    logger.info("Epoch %d LR %.6f", epoch, cur_lr)

    model.train()

    for step, (x, y) in enumerate(train_loader):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        bs = x.size(0)

        optimizer.zero_grad()
        logits, aux_logits = model(x)
        loss = criterion(logits, y)
        if config.aux_weight > 0.:
            loss += config.aux_weight * criterion(aux_logits, y)
        loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        accuracy = utils.accuracy(logits, y, topk=(1, 5))
        losses.update(loss.item(), bs)
        top1.update(accuracy["acc1"], bs)
        top5.update(accuracy["acc5"], bs)

        if step % config.log_frequency == 0 or step == len(train_loader) - 1:
            logger.info(
                "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1, config.epochs, step, len(train_loader) - 1,
                    losses=losses, top1=top1, top5=top5))

        cur_step += 1

    logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))


def validate(config, valid_loader, model, criterion, epoch, cur_step):
    top1 = AverageMeter("top1")
    top5 = AverageMeter("top5")
    losses = AverageMeter("losses")

    model.eval()

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            N = X.size(0)

            logits = model(X)
            loss = criterion(logits, y)

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(accuracy["acc1"], N)
            top5.update(accuracy["acc5"], N)

            if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
                logger.info(
                    "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, config.epochs, step, len(valid_loader) - 1,
                        losses=losses, top1=top1, top5=top5))

    logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))

    return top1.avg


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--batch-size", default=96, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--aux-weight", default=0.4, type=float)
    parser.add_argument("--drop-path-prob", default=0.2, type=float)
    parser.add_argument("--workers", default=4)
    parser.add_argument("--grad-clip", default=5., type=float)
    parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json")

    args = parser.parse_args()
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)

    model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
    apply_fixed_architecture(model, args.arc_checkpoint, device=device)
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    best_top1 = 0.
    for epoch in range(args.epochs):
        drop_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob(drop_prob)

        # training
        train(args, train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(args, valid_loader, model, criterion, epoch, cur_step)
        best_top1 = max(best_top1, top1)

        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
examples/nas/darts/search.py
@@ -13,7 +13,7 @@ from utils import accuracy
 if __name__ == "__main__":
     parser = ArgumentParser("darts")
     parser.add_argument("--layers", default=8, type=int)
-    parser.add_argument("--batch-size", default=96, type=int)
+    parser.add_argument("--batch-size", default=64, type=int)
     parser.add_argument("--log-frequency", default=10, type=int)
     parser.add_argument("--epochs", default=50, type=int)
     args = parser.parse_args()
@@ -36,4 +36,4 @@ if __name__ == "__main__":
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
                            callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
-    trainer.train_and_validate()
+    trainer.train()
examples/nas/enas/macro.py
@@ -6,7 +6,7 @@ from ops import FactorizedReduce, ConvBranch, PoolBranch
 class ENASLayer(mutables.MutableScope):

-    def __init__(self, key, num_prev_layers, in_filters, out_filters):
+    def __init__(self, key, prev_labels, in_filters, out_filters):
         super().__init__(key)
         self.in_filters = in_filters
         self.out_filters = out_filters
@@ -18,16 +18,16 @@ class ENASLayer(mutables.MutableScope):
             PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
             PoolBranch('max', in_filters, out_filters, 3, 1, 1)
         ])
-        if num_prev_layers > 0:
-            self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
+        if len(prev_labels) > 0:
+            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
         else:
             self.skipconnect = None
         self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

-    def forward(self, prev_layers, prev_labels):
+    def forward(self, prev_layers):
         out = self.mutable(prev_layers[-1])
         if self.skipconnect is not None:
-            connection = self.skipconnect(prev_layers[:-1], tags=prev_labels)
+            connection = self.skipconnect(prev_layers[:-1])
             if connection is not None:
                 out += connection
         return self.batch_norm(out)
@@ -53,11 +53,12 @@ class GeneralNetwork(nn.Module):
         self.layers = nn.ModuleList()
         self.pool_layers = nn.ModuleList()
+        labels = []
         for layer_id in range(self.num_layers):
+            labels.append("layer_{}".format(layer_id))
             if layer_id in self.pool_layers_idx:
                 self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
-            self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id,
-                                         self.out_filters, self.out_filters))
+            self.layers.append(ENASLayer(labels[-1], labels[:-1],
                                          self.out_filters, self.out_filters))

         self.gap = nn.AdaptiveAvgPool2d(1)
         self.dense = nn.Linear(self.out_filters, self.num_classes)
@@ -66,12 +67,11 @@ class GeneralNetwork(nn.Module):
         bs = x.size(0)
         cur = self.stem(x)

-        layers, labels = [cur], []
+        layers = [cur]

         for layer_id in range(self.num_layers):
-            cur = self.layers[layer_id](layers, labels)
+            cur = self.layers[layer_id](layers)
             layers.append(cur)
-            labels.append(self.layers[layer_id].key)
             if layer_id in self.pool_layers_idx:
                 for i, layer in enumerate(layers):
                     layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
 ...
examples/nas/enas/micro.py
@@ -32,9 +32,9 @@ class AuxiliaryHead(nn.Module):
 class Cell(nn.Module):

-    def __init__(self, cell_name, num_prev_layers, channels):
+    def __init__(self, cell_name, prev_labels, channels):
         super().__init__()
-        self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True,
+        self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
                                                  key=cell_name + "_input")
         self.op_choice = mutables.LayerChoice([
             SepConvBN(channels, channels, 3, 1),
@@ -44,21 +44,21 @@ class Cell(nn.Module):
             nn.Identity()
         ], key=cell_name + "_op")

-    def forward(self, prev_layers, prev_labels):
-        chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels)
+    def forward(self, prev_layers):
+        chosen_input, chosen_mask = self.input_choice(prev_layers)
         cell_out = self.op_choice(chosen_input)
         return cell_out, chosen_mask


 class Node(mutables.MutableScope):
-    def __init__(self, node_name, num_prev_layers, channels):
+    def __init__(self, node_name, prev_node_names, channels):
         super().__init__(node_name)
-        self.cell_x = Cell(node_name + "_x", num_prev_layers, channels)
-        self.cell_y = Cell(node_name + "_y", num_prev_layers, channels)
+        self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
+        self.cell_y = Cell(node_name + "_y", prev_node_names, channels)

-    def forward(self, prev_layers, prev_labels):
-        out_x, mask_x = self.cell_x(prev_layers, prev_labels)
-        out_y, mask_y = self.cell_y(prev_layers, prev_labels)
+    def forward(self, prev_layers):
+        out_x, mask_x = self.cell_x(prev_layers)
+        out_y, mask_y = self.cell_y(prev_layers)
         return out_x + out_y, mask_x | mask_y
@@ -93,8 +93,11 @@ class ENASLayer(nn.Module):
         self.num_nodes = num_nodes

         name_prefix = "reduce" if reduction else "normal"
-        self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i),
-                                         i + 2, out_channels) for i in range(num_nodes)])
+        self.nodes = nn.ModuleList()
+        node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
+        for i in range(num_nodes):
+            node_labels.append("{}_node_{}".format(name_prefix, i))
+            self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
         self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1),
                                          requires_grad=True)
         self.bn = nn.BatchNorm2d(out_channels, affine=False)
         self.reset_parameters()
@@ -106,14 +109,12 @@ class ENASLayer(nn.Module):
         pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)

         prev_nodes_out = [pprev_, prev_]
-        prev_nodes_labels = ["prev1", "prev2"]
         nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
         for i in range(self.num_nodes):
-            node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels)
+            node_out, mask = self.nodes[i](prev_nodes_out)
             nodes_used_mask[:mask.size(0)] |= mask
             prev_nodes_out.append(node_out)
-            prev_nodes_labels.append(self.nodes[i].key)

         unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
         unused_nodes = F.relu(unused_nodes)
         conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
 ...
examples/nas/enas/search.py
@@ -13,7 +13,7 @@ from utils import accuracy, reward_accuracy
 if __name__ == "__main__":
     parser = ArgumentParser("enas")
     parser.add_argument("--batch-size", default=128, type=int)
-    parser.add_argument("--log-frequency", default=1, type=int)
+    parser.add_argument("--log-frequency", default=10, type=int)
     parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
     args = parser.parse_args()
@@ -43,5 +43,6 @@ if __name__ == "__main__":
                           num_epochs=num_epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
-                          log_frequency=args.log_frequency)
-    trainer.train_and_validate()
+                          log_frequency=args.log_frequency,
+                          mutator=mutator)
+    trainer.train()
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
 import logging

 import torch.nn as nn

-from nni.nas.pytorch.mutables import Mutable
+from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
+from nni.nas.pytorch.utils import StructuredMutableTreeNode

 logger = logging.getLogger(__name__)


 class BaseMutator(nn.Module):
+    """
+    A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
+    callbacks that are called in ``forward`` in Mutables.
+    """

     def __init__(self, model):
         super().__init__()
         self.__dict__["model"] = model
-        self.before_parse_search_space()
-        self._parse_search_space()
-        self.after_parse_search_space()
-
-    def before_parse_search_space(self):
-        pass
-
-    def after_parse_search_space(self):
-        pass
-
-    def _parse_search_space(self):
-        for name, mutable, _ in self.named_mutables(distinct=False):
-            mutable.name = name
-            mutable.set_mutator(self)
-
-    def named_mutables(self, root=None, distinct=True):
-        if root is None:
-            root = self.model
-        # if distinct is true, the method will filter out those with duplicated keys
-        key2module = dict()
-        for name, module in root.named_modules():
-            if isinstance(module, Mutable):
-                module_distinct = False
-                if module.key in key2module:
-                    assert key2module[module.key].similar(module), \
-                        "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
-                else:
-                    module_distinct = True
-                    key2module[module.key] = module
-                if distinct:
-                    if module_distinct:
-                        yield name, module
-                else:
-                    yield name, module, module_distinct
-
-    def __setattr__(self, key, value):
-        if key in ["model", "net", "network"]:
-            logger.warning("Think twice if you are including the network into mutator.")
-        return super().__setattr__(key, value)
+        self._structured_mutables = self._parse_search_space(self.model)
+
+    def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
+        if memo is None:
+            memo = set()
+        if root is None:
+            root = StructuredMutableTreeNode(None)
+        if module not in memo:
+            memo.add(module)
+            if isinstance(module, Mutable):
+                if nested_detection is not None:
+                    raise RuntimeError("Cannot have nested search space. Error at {} in {}"
+                                       .format(module, nested_detection))
+                module.name = prefix
+                module.set_mutator(self)
+                root = root.add_child(module)
+                if not isinstance(module, MutableScope):
+                    nested_detection = module
+                if isinstance(module, InputChoice):
+                    for k in module.choose_from:
+                        if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
+                            raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
+                                               .format(k, module.key))
+            for name, submodule in module._modules.items():
+                if submodule is None:
+                    continue
+                submodule_prefix = prefix + ("." if prefix else "") + name
+                self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
+                                         nested_detection=nested_detection)
+        return root
+
+    @property
+    def mutables(self):
+        return self._structured_mutables

     def forward(self, *inputs):
-        raise NotImplementedError("Mutator is not forward-able")
+        raise RuntimeError("Forward is undefined for mutators.")

     def enter_mutable_scope(self, mutable_scope):
+        """
+        Callback when forward of a MutableScope is entered.
+
+        Parameters
+        ----------
+        mutable_scope: MutableScope
+
+        Returns
+        -------
+        None
+        """
         pass

     def exit_mutable_scope(self, mutable_scope):
+        """
+        Callback when forward of a MutableScope is exited.
+
+        Parameters
+        ----------
+        mutable_scope: MutableScope
+
+        Returns
+        -------
+        None
+        """
         pass

     def on_forward_layer_choice(self, mutable, *inputs):
+        """
+        Callbacks of forward in LayerChoice.
+
+        Parameters
+        ----------
+        mutable: LayerChoice
+        inputs: list of torch.Tensor
+
+        Returns
+        -------
+        tuple of torch.Tensor and torch.Tensor
+            output tensor and mask
+        """
         raise NotImplementedError

-    def on_forward_input_choice(self, mutable, tensor_list, tags):
+    def on_forward_input_choice(self, mutable, tensor_list):
+        """
+        Callbacks of forward in InputChoice.
+
+        Parameters
+        ----------
+        mutable: InputChoice
+        tensor_list: list of torch.Tensor
+
+        Returns
+        -------
+        tuple of torch.Tensor and torch.Tensor
+            output tensor and mask
+        """
         raise NotImplementedError

     def export(self):
+        """
+        Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
+        network can be fully determined with these decisions for further training from scratch.
+
+        Returns
+        -------
+        dict
+        """
         raise NotImplementedError
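Annotation (not part of the commit page): after this change, the parsed search space is exposed as an iterable tree via the `mutables` property, and concrete mutators make decisions through the sample_search/sample_final contract used by DartsMutator and EnasMutator below. A minimal custom mutator sketched against that contract; "FirstChoiceMutator" is a made-up name, and the attributes used (mutable.key, mutable.length, mutable.n_candidates, mutable.n_chosen) are the ones visible in this commit's diffs.

# Sketch: a trivial mutator that always keeps the first candidate of every choice.
import torch
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

class FirstChoiceMutator(Mutator):
    def _pick_first(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                mask = torch.zeros(mutable.length, dtype=torch.bool)
                mask[0] = True                         # always take the first candidate op
                result[mutable.key] = mask
            elif isinstance(mutable, InputChoice):
                mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                mask[:mutable.n_chosen or 1] = True    # keep the first n_chosen inputs (1 if unconstrained)
                result[mutable.key] = mask
        return result

    def sample_search(self):
        return self._pick_first()

    def sample_final(self):
        return self._pick_first()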
src/sdk/pynni/nni/nas/pytorch/base_trainer.py
@@ -12,5 +12,9 @@ class BaseTrainer(ABC):
         raise NotImplementedError

     @abstractmethod
-    def train_and_validate(self):
+    def export(self, file):
+        raise NotImplementedError
+
+    @abstractmethod
+    def checkpoint(self):
         raise NotImplementedError
src/sdk/pynni/nni/nas/pytorch/callbacks.py
-import json
 import logging
 import os

-import torch
-
 _logger = logging.getLogger(__name__)

@@ -44,26 +41,11 @@ class LearningRateScheduler(Callback):
 class ArchitectureCheckpoint(Callback):
-    class TorchTensorEncoder(json.JSONEncoder):
-        def default(self, o):  # pylint: disable=method-hidden
-            if isinstance(o, torch.Tensor):
-                olist = o.tolist()
-                if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
-                    _logger.warning("Every element in %s is either 0 or 1. "
-                                    "You might consider convert it into bool.", olist)
-                return olist
-            return super().default(o)
-
     def __init__(self, checkpoint_dir, every="epoch"):
         super().__init__()
         assert every == "epoch"
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)

-    def _export_to_file(self, file):
-        mutator_export = self.mutator.export()
-        with open(file, "w") as f:
-            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder)
-
     def on_epoch_end(self, epoch):
-        self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
+        self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
 from .mutator import DartsMutator
 from .trainer import DartsTrainer
-from .scope import DartsNode
\ No newline at end of file
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
@@ -2,35 +2,47 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F

-from nni.nas.pytorch.mutables import LayerChoice
 from nni.nas.pytorch.mutator import Mutator
-from .scope import DartsNode
+from nni.nas.pytorch.mutables import LayerChoice, InputChoice


 class DartsMutator(Mutator):
-    def after_parse_search_space(self):
+    def __init__(self, model):
+        super().__init__(model)
         self.choices = nn.ParameterDict()
-        for _, mutable in self.named_mutables():
+        for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1))
+                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))

-    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
-        return F.softmax(self.choices[mutable.key], dim=-1)[:-1]
+    def device(self):
+        for v in self.choices.values():
+            return v.device

-    def export(self):
-        result = super().export()
-        for _, darts_node in self.named_mutables():
-            if isinstance(darts_node, DartsNode):
-                keys, edges_max = [], []  # key of all the layer choices in current node, and their best edge weight
-                for _, choice in self.named_mutables(darts_node):
-                    if isinstance(choice, LayerChoice):
-                        keys.append(choice.key)
-                        max_val, index = torch.max(result[choice.key], 0)
-                        edges_max.append(max_val)
-                        result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool()
-                _, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation)  # pylint: disable=not-callable
-                for i, key in enumerate(keys):
-                    if i not in topk_edge_indices:
-                        result[key] = torch.zeros_like(result[key])
-        return result
+    def sample_search(self):
+        result = dict()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
+            elif isinstance(mutable, InputChoice):
+                result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
+        return result
+
+    def sample_final(self):
+        result = dict()
+        edges_max = dict()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
+                edges_max[mutable.key] = max_val
+                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
+        for mutable in self.mutables:
+            if isinstance(mutable, InputChoice):
+                weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from])  # pylint: disable=not-callable
+                _, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates)
+                selected_multihot = []
+                for i, src_key in enumerate(mutable.choose_from):
+                    if i not in topk_edge_indices and src_key in result:
+                        result[src_key] = torch.zeros_like(result[src_key])  # clear this choice to optimize calc graph
+                    selected_multihot.append(i in topk_edge_indices)
+                result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
+        return result
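Annotation (not part of the commit page): the selection rule in the new sample_final boils down to an argmax over each LayerChoice's softmaxed weights (excluding the trailing "none" slot), followed by keeping the top-k incoming edges per node through the InputChoice mask. A standalone illustration with made-up architecture weights for three incoming edges of one node; the keys are hypothetical.

# Worked example of the sample_final selection rule on dummy tensors.
import torch
import torch.nn.functional as F

alpha = {                                      # raw architecture parameters: 3 ops + 1 trailing "none" slot
    "node_p0": torch.tensor([0.1, 0.9, 0.2, 0.0]),
    "node_p1": torch.tensor([0.5, 0.1, 0.1, 0.0]),
    "node_p2": torch.tensor([0.3, 0.2, 0.8, 0.0]),
}
edges_max, result = {}, {}
for key, a in alpha.items():
    weights = F.softmax(a, dim=-1)[:-1]        # drop the trailing "none" entry, as in the diff above
    max_val, index = torch.max(weights, 0)
    edges_max[key] = max_val                   # best op weight on this edge
    result[key] = F.one_hot(index, num_classes=weights.size(0)).view(-1).bool()

choose_from = list(alpha.keys())
w = torch.tensor([edges_max[k] for k in choose_from])
_, topk = torch.topk(w, 2)                     # keep the 2 strongest incoming edges (n_chosen = 2)
result["node_switch"] = torch.tensor([i in topk for i in range(len(choose_from))], dtype=torch.bool)
print(result)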
src/sdk/pynni/nni/nas/pytorch/darts/scope.py  (deleted, 100644 → 0)
-from nni.nas.pytorch.mutables import MutableScope
-
-
-class DartsNode(MutableScope):
-    """
-    At most `limitation` choice is activated in a `DartsNode` when exporting.
-    """
-
-    def __init__(self, key, limitation):
-        super().__init__(key)
-        self.limitation = limitation
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn as nn

 from nni.nas.pytorch.trainer import Trainer
-from nni.nas.utils import AverageMeterGroup
+from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import DartsMutator
@@ -13,9 +13,9 @@ class DartsTrainer(Trainer):
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                  callbacks=None):
-        super().__init__(model, loss, metrics, optimizer, num_epochs,
-                         dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
-                         mutator if mutator is not None else DartsMutator(model), callbacks)
+        super().__init__(model, mutator if mutator is not None else DartsMutator(model),
+                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
+                         batch_size, workers, device, log_frequency, callbacks)
         self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         n_train = len(self.dataset_train)
@@ -31,6 +31,9 @@ class DartsTrainer(Trainer):
                                                         batch_size=batch_size,
                                                         sampler=valid_sampler,
                                                         num_workers=workers)
+        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
+                                                       batch_size=batch_size,
+                                                       num_workers=workers)

     def train_one_epoch(self, epoch):
         self.model.train()
@@ -47,8 +50,8 @@ class DartsTrainer(Trainer):
             # phase 1. child network step
             self.optimizer.zero_grad()
-            with self.mutator.forward_pass():
+            self.mutator.reset()
             logits = self.model(trn_X)
             loss = self.loss(logits, trn_y)
             loss.backward()
             # gradient clipping
@@ -76,10 +79,10 @@ class DartsTrainer(Trainer):
         self.mutator.eval()
         meters = AverageMeterGroup()
         with torch.no_grad():
-            for step, (X, y) in enumerate(self.valid_loader):
+            self.mutator.reset()
+            for step, (X, y) in enumerate(self.test_loader):
                 X, y = X.to(self.device), y.to(self.device)
-                with self.mutator.forward_pass():
-                    logits = self.model(X)
+                logits = self.model(X)
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
                 if self.log_frequency is not None and step % self.log_frequency == 0:
@@ -93,8 +96,8 @@ class DartsTrainer(Trainer):
         v_model: backup model before this step
         lr: learning rate for virtual gradient step (same as net lr)
         """
-        with self.mutator.forward_pass():
+        self.mutator.reset()
         loss = self.loss(self.model(val_X), val_y)
         w_model = tuple(self.model.parameters())
         w_ctrl = tuple(self.mutator.parameters())
         w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
@@ -125,8 +128,8 @@ class DartsTrainer(Trainer):
             for p, d in zip(self.model.parameters(), dw):
                 p += eps * d
-            with self.mutator.forward_pass():
+            self.mutator.reset()
             loss = self.loss(self.model(trn_X), trn_y)
             if e > 0:
                 dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
             elif e < 0:
 ...
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from nni.nas.pytorch.mutables import LayerChoice
 from nni.nas.pytorch.mutator import Mutator
+from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope


 class StackedLSTMCell(nn.Module):
@@ -27,15 +27,14 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):

     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                  skip_target=0.4, branch_bias=0.25):
+        super().__init__(model)
         self.lstm_size = lstm_size
         self.lstm_num_layers = lstm_num_layers
         self.tanh_constant = tanh_constant
         self.cell_exit_extra_step = cell_exit_extra_step
         self.skip_target = skip_target
         self.branch_bias = branch_bias
-        super().__init__(model)
-
-    def before_parse_search_space(self):
         self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
         self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
@@ -45,9 +44,8 @@ class EnasMutator(Mutator):
         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
         self.bias_dict = nn.ParameterDict()

-    def after_parse_search_space(self):
         self.max_layer_choice = 0
-        for _, mutable in self.named_mutables():
+        for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 if self.max_layer_choice == 0:
                     self.max_layer_choice = mutable.length
@@ -64,8 +62,29 @@ class EnasMutator(Mutator):
         self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
         self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

-    def before_pass(self):
-        super().before_pass()
+    def sample_search(self):
+        self._initialize()
+        self._sample(self.mutables)
+        return self._choices
+
+    def sample_final(self):
+        return self.sample_search()
+
+    def _sample(self, tree):
+        mutable = tree.mutable
+        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
+            self._choices[mutable.key] = self._sample_layer_choice(mutable)
+        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
+            self._choices[mutable.key] = self._sample_input_choice(mutable)
+        for child in tree.children:
+            self._sample(child)
+        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
+            if self.cell_exit_extra_step:
+                self._lstm_next_step()
+            self._mark_anchor(mutable.key)
+
+    def _initialize(self):
+        self._choices = dict()
         self._anchors_hid = dict()
         self._inputs = self.g_emb.data
         self._c = [torch.zeros((1, self.lstm_size),
@@ -84,7 +103,7 @@ class EnasMutator(Mutator):
     def _mark_anchor(self, key):
         self._anchors_hid[key] = self._h[-1]

-    def on_calc_layer_choice_mask(self, mutable):
+    def _sample_layer_choice(self, mutable):
         self._lstm_next_step()
         logit = self.soft(self._h[-1])
         if self.tanh_constant is not None:
@@ -94,14 +113,14 @@ class EnasMutator(Mutator):
         branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
         log_prob = self.cross_entropy_loss(logit, branch_id)
         self.sample_log_prob += torch.sum(log_prob)
-        entropy = (log_prob * torch.exp(-log_prob)).detach()
+        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
         self.sample_entropy += torch.sum(entropy)
         self._inputs = self.embedding(branch_id)
         return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

-    def on_calc_input_choice_mask(self, mutable, tags):
+    def _sample_input_choice(self, mutable):
         query, anchors = [], []
-        for label in tags:
+        for label in mutable.choose_from:
             if label not in self._anchors_hid:
                 self._lstm_next_step()
                 self._mark_anchor(label)  # empty loop, fill not found
@@ -113,8 +132,8 @@ class EnasMutator(Mutator):
         if self.tanh_constant is not None:
             query = self.tanh_constant * torch.tanh(query)
-        if mutable.n_selected is None:
-            logit = torch.cat([-query, query], 1)
+        if mutable.n_chosen is None:
+            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type
             skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
             skip_prob = torch.sigmoid(logit)
@@ -123,19 +142,14 @@ class EnasMutator(Mutator):
             log_prob = self.cross_entropy_loss(logit, skip)
             self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
         else:
-            assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS."
+            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
             logit = query.view(1, -1)
             index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
-            skip = F.one_hot(index).view(-1)
+            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
             log_prob = self.cross_entropy_loss(logit, index)
             self._inputs = anchors[index.item()]

         self.sample_log_prob += torch.sum(log_prob)
-        entropy = (log_prob * torch.exp(-log_prob)).detach()
+        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
         self.sample_entropy += torch.sum(entropy)
         return skip.bool()
-
-    def exit_mutable_scope(self, mutable_scope):
-        if self.cell_exit_extra_step:
-            self._lstm_next_step()
-            self._mark_anchor(mutable_scope.key)
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
@@ -2,7 +2,7 @@ import torch
 import torch.optim as optim

 from nni.nas.pytorch.trainer import Trainer
-from nni.nas.utils import AverageMeterGroup
+from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import EnasMutator
@@ -12,9 +12,9 @@ class EnasTrainer(Trainer):
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                  entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                  mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
-        super().__init__(model, loss, metrics, optimizer, num_epochs,
-                         dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
-                         mutator if mutator is not None else EnasMutator(model), callbacks)
+        super().__init__(model, mutator if mutator is not None else EnasMutator(model),
+                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
+                         batch_size, workers, device, log_frequency, callbacks)
         self.reward_function = reward_function
         self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
@@ -52,8 +52,9 @@ class EnasTrainer(Trainer):
             x, y = x.to(self.device), y.to(self.device)
             self.optimizer.zero_grad()

-            with self.mutator.forward_pass():
-                logits = self.model(x)
+            with torch.no_grad():
+                self.mutator.reset()
+            logits = self.model(x)
             if isinstance(logits, tuple):
                 logits, aux_logits = logits
@@ -81,7 +82,8 @@ class EnasTrainer(Trainer):
             for step, (x, y) in enumerate(self.valid_loader):
                 x, y = x.to(self.device), y.to(self.device)

-                with self.mutator.forward_pass():
+                self.mutator.reset()
+                with torch.no_grad():
                     logits = self.model(x)
                 metrics = self.metrics(logits, y)
                 reward = self.reward_function(logits, y)
@@ -107,9 +109,9 @@ class EnasTrainer(Trainer):
                     self.mutator_optim.zero_grad()

                 if self.log_frequency is not None and step % self.log_frequency == 0:
-                    print("Mutator Epoch [{}/{}] Step [{}/{}]  {}".format(epoch, self.num_epochs,
+                    print("RL Epoch [{}/{}] Step [{}/{}]  {}".format(epoch, self.num_epochs,
                                                                          mutator_step // self.mutator_steps_aggregate,
                                                                          self.mutator_steps, meters))
                 mutator_step += 1
                 if mutator_step >= total_mutator_steps:
                     break
 ...
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @ 77e91e8b
...
@@ -2,10 +2,12 @@ import json
 import torch

+from nni.nas.pytorch.mutables import MutableScope
 from nni.nas.pytorch.mutator import Mutator


 class FixedArchitecture(Mutator):

     def __init__(self, model, fixed_arc, strict=True):
         """
         Initialize a fixed architecture mutator.
...
@@ -20,39 +22,57 @@ class FixedArchitecture(Mutator):
         Force everything that appears in `fixed_arc` to be used at least once.
         """
         super().__init__(model)
-        if isinstance(fixed_arc, str):
-            with open(fixed_arc, "r") as f:
-                fixed_arc = json.load(f.read())
         self._fixed_arc = fixed_arc
-        self._strict = strict
-
-    def _encode_tensor(self, data):
-        if isinstance(data, list):
-            if all(map(lambda o: isinstance(o, bool), data)):
-                return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
-            else:
-                return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
-        if isinstance(data, dict):
-            return {k: self._encode_tensor(v) for k, v in data.items()}
-        return data
-
-    def before_pass(self):
-        self._unused_key = set(self._fixed_arc.keys())
-
-    def after_pass(self):
-        if self._strict:
-            if self._unused_key:
-                raise ValueError("{} are never used by the network. "
-                                 "Set strict=False if you want to disable this check.".format(self._unused_key))
-
-    def _check_key(self, key):
-        if key not in self._fixed_arc:
-            raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key))
-
-    def on_calc_layer_choice_mask(self, mutable):
-        self._check_key(mutable.key)
-        return self._fixed_arc[mutable.key]
-
-    def on_calc_input_choice_mask(self, mutable, tags):
-        self._check_key(mutable.key)
-        return self._fixed_arc[mutable.key]
+
+        mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
+        fixed_arc_keys = set(self._fixed_arc.keys())
+        if fixed_arc_keys - mutable_keys:
+            raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
+        if mutable_keys - fixed_arc_keys:
+            raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
+
+    def sample_search(self):
+        return self._fixed_arc
+
+    def sample_final(self):
+        return self._fixed_arc
+
+
+def _encode_tensor(data, device):
+    if isinstance(data, list):
+        if all(map(lambda o: isinstance(o, bool), data)):
+            return torch.tensor(data, dtype=torch.bool, device=device)  # pylint: disable=not-callable
+        else:
+            return torch.tensor(data, dtype=torch.float, device=device)  # pylint: disable=not-callable
+    if isinstance(data, dict):
+        return {k: _encode_tensor(v, device) for k, v in data.items()}
+    return data
+
+
+def apply_fixed_architecture(model, fixed_arc_path, device=None):
+    """
+    Load architecture from `fixed_arc_path` and apply to model.
+
+    Parameters
+    ----------
+    model: torch.nn.Module
+        Model with mutables.
+    fixed_arc_path: str
+        Path to the JSON that stores the architecture.
+    device: torch.device
+        Architecture weights will be transferred to `device`.
+
+    Returns
+    -------
+    FixedArchitecture
+    """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if isinstance(fixed_arc_path, str):
+        with open(fixed_arc_path, "r") as f:
+            fixed_arc = json.load(f)
+    fixed_arc = _encode_tensor(fixed_arc, device)
+    architecture = FixedArchitecture(model, fixed_arc)
+    architecture.to(device)
+    architecture.reset()
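A minimal sketch of how the new `apply_fixed_architecture` helper might be used for retraining a searched architecture. The checkpoint path, the `build_search_space` factory, and `train_loader` are placeholders, not part of this commit:

import torch
import torch.nn as nn

from nni.nas.pytorch.fixed import apply_fixed_architecture

# Hypothetical model containing LayerChoice / InputChoice mutables;
# its mutable keys must match those stored in the exported JSON.
model = build_search_space()          # assumption: user-defined factory
apply_fixed_architecture(model, "checkpoints/epoch_49.json")  # placeholder path

# After the fixed mutator is applied, every mutable resolves to the stored
# decision, so the model trains like an ordinary nn.Module.
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for x, y in train_loader:             # assumption: user-provided DataLoader
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()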
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @ 77e91e8b
 import torch.nn as nn

-from nni.nas.utils import global_mutable_counting
+from nni.nas.pytorch.utils import global_mutable_counting


 class Mutable(nn.Module):
...
@@ -37,7 +37,7 @@ class Mutable(nn.Module):
         self.__dict__["mutator"] = mutator

     def forward(self, *inputs):
-        raise NotImplementedError("Mutable forward must be implemented.")
+        raise NotImplementedError

     @property
     def key(self):
...
@@ -51,9 +51,6 @@ class Mutable(nn.Module):
     def name(self, name):
         self._name = name

-    def similar(self, other):
-        return type(self) == type(other)
-
     def _check_built(self):
         if not hasattr(self, "mutator"):
             raise ValueError(
...
@@ -66,19 +63,17 @@ class Mutable(nn.Module):
 class MutableScope(Mutable):
     """
-    Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
+    Mutable scope labels a subgraph/submodule to help mutators make better decisions. Mutators get notified when a mutable scope
     is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
     corresponding events, and do status dump or update.
     """

     def __init__(self, key):
         super().__init__(key=key)

-    def build(self):
-        self.mutator.on_init_mutable_scope(self)
-
     def __call__(self, *args, **kwargs):
         try:
+            self._check_built()
             self.mutator.enter_mutable_scope(self)
             return super().__call__(*args, **kwargs)
         finally:
...
@@ -93,43 +88,92 @@ class LayerChoice(Mutable):
         self.reduction = reduction
         self.return_mask = return_mask

+    def __len__(self):
+        return len(self.choices)
+
     def forward(self, *inputs):
         out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
         if self.return_mask:
             return out, mask
         return out

-    def similar(self, other):
-        return type(self) == type(other) and self.length == other.length
-

 class InputChoice(Mutable):
-    def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
+    """
+    Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners,
+    using `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
+    know about `choose_from`.
+
+    The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
+    The keys are designed to be the keys of the sources. To help mutators make better decisions,
+    mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
+    output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
+    ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
+    module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
+    """
+
+    NO_KEY = ""
+
+    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
+                 reduction="mean", return_mask=False, key=None):
+        """
+        Initialization.
+
+        Parameters
+        ----------
+        n_candidates: int
+            Number of inputs to choose from.
+        choose_from: list of str
+            List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be fulfilled.
+            If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
+            number of empty string.
+        n_chosen: int
+            Recommended inputs to choose. If None, mutator is instructed to select any.
+        reduction: str
+            `mean`, `concat`, `sum` or `none`.
+        return_mask: bool
+            If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
+        key: str
+            Key of the input choice.
+        """
         super().__init__(key=key)
+        # precondition check
+        assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
+                                                                    "must be not None."
+        if choose_from is not None and n_candidates is None:
+            n_candidates = len(choose_from)
+        elif choose_from is None and n_candidates is not None:
+            choose_from = [self.NO_KEY] * n_candidates
+        assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
         assert n_candidates > 0, "Number of candidates must be greater than 0."
+        assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
+                                                                  "than number of candidates."
+
         self.n_candidates = n_candidates
-        self.n_selected = n_selected
+        self.choose_from = choose_from
+        self.n_chosen = n_chosen
         self.reduction = reduction
         self.return_mask = return_mask

-    def build(self):
-        self.mutator.on_init_input_choice(self)
-
-    def forward(self, optional_inputs, tags=None):
+    def forward(self, optional_inputs):
+        """
+        Forward method of InputChoice.
+
+        Parameters
+        ----------
+        optional_inputs: list or dict
+            Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
+            `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
+            `choose_from`.
+
+        Returns
+        -------
+        tuple of torch.Tensor and torch.Tensor or torch.Tensor
+        """
+        optional_input_list = optional_inputs
+        if isinstance(optional_inputs, dict):
+            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
+        assert isinstance(optional_input_list, list), "Optional input list must be a list"
         assert len(optional_inputs) == self.n_candidates, \
             "Length of the input list must be equal to number of candidates."
-        if tags is None:
-            tags = [""] * self.n_candidates
-        else:
-            assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
-        out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
+        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
         if self.return_mask:
             return out, mask
         return out
-
-    def similar(self, other):
-        return type(self) == type(other) and \
-            self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
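A minimal sketch of how the reworked `InputChoice` might be declared with `choose_from` keys and called with a dict of candidate tensors; the key names and module layout below are illustrative, not taken from this commit:

import torch
import torch.nn as nn

from nni.nas.pytorch.mutables import InputChoice, LayerChoice


class Node(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Candidate operations; the mutator decides which one runs.
        self.op = LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ], key="node_op")
        # Pick one of the two predecessor outputs, identified by the keys
        # listed in `choose_from` (hypothetical keys "prev_0" / "prev_1").
        self.skip = InputChoice(n_candidates=2,
                                choose_from=["prev_0", "prev_1"],
                                n_chosen=1,
                                reduction="sum",
                                key="node_input")

    def forward(self, prev_0, prev_1):
        out = self.op(prev_0)
        # Passing a dict maps each candidate tensor to its source key;
        # InputChoice reorders it to follow `choose_from`.
        chosen = self.skip({"prev_0": prev_0, "prev_1": prev_1})
        return out + chosen

A mutator (for example the `FixedArchitecture` above or a search mutator) has to be applied to the model before `forward` is called, since each mutable looks its decision up through `self.mutator`.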
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @ 77e91e8b
-from contextlib import contextmanager
-
 import torch
-import torch.nn as nn

 from nni.nas.pytorch.base_mutator import BaseMutator


-class Mutator(BaseMutator, nn.Module):
+class Mutator(BaseMutator):

-    def export(self):
-        if self._in_forward_pass:
-            raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
-        if not self._cache:
-            raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
-                               "You might also want to check if there are no valid mutables in your model.")
-        return self._cache
-
-    @contextmanager
-    def forward_pass(self):
-        self._in_forward_pass = True
-        self._cache = dict()
-        self.before_pass()
-        try:
-            yield self
-        finally:
-            self.after_pass()
-            self._in_forward_pass = False
-
-    def before_pass(self):
-        pass
-
-    def after_pass(self):
-        pass
-
-    def _check_in_forward_pass(self):
-        if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
-            raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
-                             "super().before_pass() and after_pass() in your override method?")
+    def __init__(self, model):
+        super().__init__(model)
+        self._cache = dict()
+
+    def sample_search(self):
+        """
+        Override to implement this method to iterate over mutables and make decisions.
+
+        Returns
+        -------
+        dict
+            A mapping from key of mutables to decisions.
+        """
+        raise NotImplementedError
+
+    def sample_final(self):
+        """
+        Override to implement this method to iterate over mutables and make decisions that are final
+        for export and retraining.
+
+        Returns
+        -------
+        dict
+            A mapping from key of mutables to decisions.
+        """
+        raise NotImplementedError
+
+    def reset(self):
+        """
+        Reset the mutator by calling `sample_search` to resample (for search).
+
+        Returns
+        -------
+        None
+        """
+        self._cache = self.sample_search()
+
+    def export(self):
+        """
+        Resample (for final) and return results.
+
+        Returns
+        -------
+        dict
+        """
+        return self.sample_final()

     def on_forward_layer_choice(self, mutable, *inputs):
         """
-        Callback of layer choice forward. Override if you are an advanced user.
         On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
         (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
         in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
...
@@ -54,18 +67,17 @@ class Mutator(BaseMutator, nn.Module):
         -------
         tuple of torch.Tensor and torch.Tensor
         """
-        self._check_in_forward_pass()
-
         def _map_fn(op, *inputs):
             return op(*inputs)

-        mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
+        mask = self._get_decision(mutable)
+        assert len(mask) == len(mutable.choices)
         out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

-    def on_forward_input_choice(self, mutable, tensor_list, tags):
+    def on_forward_input_choice(self, mutable, tensor_list):
         """
-        Callback of input choice forward. Override if you are an advanced user.
         On default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
         to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
         the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
...
@@ -81,48 +93,11 @@ class Mutator(BaseMutator, nn.Module):
         -------
         tuple of torch.Tensor and torch.Tensor
         """
-        self._check_in_forward_pass()
-        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags))
+        mask = self._get_decision(mutable)
+        assert len(mask) == mutable.n_candidates
         out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

-    def on_calc_layer_choice_mask(self, mutable):
-        """
-        Recommended to override. Calculate a mask tensor for a layer choice.
-
-        Parameters
-        ----------
-        mutable: LayerChoice
-            Corresponding layer choice object.
-
-        Returns
-        -------
-        torch.Tensor
-            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
-            the numbers are treated as switch.
-        """
-        raise NotImplementedError("Layer choice mask calculation must be implemented")
-
-    def on_calc_input_choice_mask(self, mutable, tags):
-        """
-        Recommended to override. Calculate a mask tensor for a input choice.
-
-        Parameters
-        ----------
-        mutable: InputChoice
-            Corresponding input choice object.
-        tags: list of string
-            The name of labels of input tensors given by user. Usually it's a
-            :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.
-
-        Returns
-        -------
-        torch.Tensor
-            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
-            the numbers are treated as switch.
-        """
-        raise NotImplementedError("Input choice mask calculation must be implemented")
-
     def _select_with_mask(self, map_fn, candidates, mask):
         if "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
...
@@ -146,3 +121,20 @@ class Mutator(BaseMutator, nn.Module):
         if reduction_type == "concat":
             return torch.cat(tensor_list, dim=1)
         raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
+
+    def _get_decision(self, mutable):
+        """
+        By default, this method checks whether `mutable.key` is already in the decision cache,
+        and returns the result without double-check.
+
+        Parameters
+        ----------
+        mutable: Mutable
+
+        Returns
+        -------
+        any
+        """
+        if mutable.key not in self._cache:
+            raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
+        return self._cache[mutable.key]
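A minimal sketch of a mutator written against the new offline-decision interface, assuming decisions are sampled uniformly at random; `RandomMutator` is hypothetical and not part of this commit. Only `sample_search` and `sample_final` are overridden; `reset()` caches a sample before each forward pass, and `export()` produces the final decisions:

import torch

from nni.nas.pytorch.mutables import InputChoice, LayerChoice
from nni.nas.pytorch.mutator import Mutator


class RandomMutator(Mutator):
    """Samples a random one-hot decision for every mutable (illustrative only)."""

    def _sample(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # Boolean mask with exactly one choice switched on.
                index = torch.randint(len(mutable.choices), (1,)).item()
                result[mutable.key] = torch.zeros(len(mutable.choices), dtype=torch.bool)
                result[mutable.key][index] = True
            elif isinstance(mutable, InputChoice):
                index = torch.randint(mutable.n_candidates, (1,)).item()
                result[mutable.key] = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                result[mutable.key][index] = True
        return result

    def sample_search(self):
        return self._sample()

    def sample_final(self):
        return self._sample()


# mutator = RandomMutator(model)   # model holds LayerChoice / InputChoice mutables
# mutator.reset()                  # cache one architecture for the next forward pass
# output = model(inputs)
# arc = mutator.export()           # decisions keyed by mutable.key, ready for JSON export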
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py
View file @ 77e91e8b
...
@@ -11,14 +11,14 @@ from nni.nas.pytorch.mutables import LayerChoice

 class PdartsMutator(DartsMutator):

-    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
+    def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
         self.pdarts_epoch_index = pdarts_epoch_index
         self.pdarts_num_to_drop = pdarts_num_to_drop
         self.switches = switches

-        super(PdartsMutator, self).__init__(model)
+        super(PdartsMutator, self).__init__()

-    def before_build(self, model):
+    def before_build(self):
         self.choices = nn.ParameterDict()
         if self.switches is None:
             self.switches = {}
...