OpenDAS / nni · Commits · d1d10de7

Commit d1d10de7 (unverified), authored Nov 14, 2019 by Chi Song; committed by GitHub, Nov 14, 2019.

pdarts implementation (export is not included) (#1730)

Parent: d43fbe82
Showing 17 changed files with 421 additions and 25 deletions (+421 / -25).
.gitignore                                           +1   -0
examples/nas/.gitignore                              +1   -1
examples/nas/darts/search.py                         +5   -7
examples/nas/pdarts/.gitignore                       +2   -0
examples/nas/pdarts/datasets.py                      +25  -0
examples/nas/pdarts/main.py                          +65  -0
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py      +2   -0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py      +69  -0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py   +73  -0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py       +13  -10
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py       +2   -1
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py        +2   -2
src/sdk/pynni/nni/nas/pytorch/modules.py             +9   -0
src/sdk/pynni/nni/nas/pytorch/mutables.py            +4   -4
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py     +1   -0
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py      +93  -0
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py      +54  -0
.gitignore

@@ -80,6 +80,7 @@ venv.bak/
 # VSCode
 .vscode
 .vs

 # In case you place source code in ~/nni/
+/experiments
examples/nas/.gitignore (+1 / -1; diff not expanded in this view)
examples/nas/darts/search.py

 from argparse import ArgumentParser

-import datasets
 import torch
 import torch.nn as nn
-from model import SearchCNN
-from nni.nas.pytorch.darts import DartsTrainer
+
+import datasets
+from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
 from utils import accuracy

 if __name__ == "__main__":
     parser = ArgumentParser("darts")
-    parser.add_argument("--layers", default=4, type=int)
-    parser.add_argument("--nodes", default=2, type=int)
+    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--nodes", default=4, type=int)
     parser.add_argument("--batch-size", default=128, type=int)
     parser.add_argument("--log-frequency", default=1, type=int)
     args = parser.parse_args()

     dataset_train, dataset_valid = datasets.get_dataset("cifar10")

-    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
+    model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
     criterion = nn.CrossEntropyLoss()

     optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
examples/nas/pdarts/.gitignore (new file, mode 100644)

data/*
log
examples/nas/pdarts/datasets.py (new file, mode 100644)

from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
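
For reference, a minimal usage sketch of the helper above (illustrative, not part of this commit; torchvision downloads CIFAR-10 on the first call):

import datasets

train_set, valid_set = datasets.get_dataset("cifar10")
print(len(train_set), len(valid_set))  # 50000 10000
sample, label = train_set[0]
print(sample.shape)  # torch.Size([3, 32, 32]) after ToTensor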
examples/nas/pdarts/main.py (new file, mode 100644)

from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    parser.add_argument('--add_layers', action='append', default=[0, 6, 12], help='add layers')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        n_epochs = 50
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=50,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()
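
A quick sanity check of the accuracy helper above (illustrative, not part of the commit; the import is hypothetical and assumes main.py is on the path):

import torch
from main import accuracy  # hypothetical import for illustration

logits = torch.tensor([[0.1, 0.9],   # predicts class 1
                       [0.8, 0.2]])  # predicts class 0
target = torch.tensor([1, 0])
print(accuracy(logits, target, topk=(1,)))  # {'acc1': 1.0}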
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py

 from .mutator import DartsMutator
 from .trainer import DartsTrainer
+from .cnn_cell import CnnCell
+from .cnn_network import CnnNetwork
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py (new file, mode 100644)

import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule

from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv


class CnnCell(RankedModule):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialize a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is reduction cell or not.
        reduction: bool
            Flag for whether the current cell is reduction cell or not.
        """
        super(CnnCell, self).__init__(rank=1, reduction=reduction)
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                m_ops = []
                for primitive in PRIMITIVES:
                    op = OPS[primitive](channels, stride, False)
                    m_ops.append(op)
                op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
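
A construction sketch for the cell (illustrative, not part of the commit). Note that the LayerChoice modules need a mutator attached, which the trainer normally does, before forward() can run:

from nni.nas.pytorch.darts import CnnCell

# Normal (non-reduction) cell with 4 intermediate nodes.
cell = CnnCell(n_nodes=4, channels_pp=16, channels_p=16, channels=16,
               reduction_p=False, reduction=False)
# Node k has 2 + k incoming edges, each a LayerChoice over all PRIMITIVES.
print([len(ops) for ops in cell.mutable_ops])  # [2, 3, 4, 5]
# forward(s0, s1) concatenates the 4 node outputs: 4 * 16 = 64 channels.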
examples/nas/darts/model.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py

-import torch
 import torch.nn as nn

-import ops
-from nni.nas import pytorch as nas
-
-
-class SearchCell(nn.Module):
-    """
-    Cell for search.
-    """
-
-    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
-        """
-        Initialization a search cell.
-
-        Parameters
-        ----------
-        n_nodes: int
-            Number of nodes in current DAG.
-        channels_pp: int
-            Number of output channels from previous previous cell.
-        channels_p: int
-            Number of output channels from previous cell.
-        channels: int
-            Number of channels that will be used in the current DAG.
-        reduction_p: bool
-            Flag for whether the previous cell is reduction cell or not.
-        reduction: bool
-            Flag for whether the current cell is reduction cell or not.
-        """
-        super().__init__()
-        self.reduction = reduction
-        self.n_nodes = n_nodes
-
-        # If previous cell is reduction cell, current input size does not match with
-        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
-        if reduction_p:
-            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
-        else:
-            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
-        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
-
-        # generate dag
-        self.mutable_ops = nn.ModuleList()
-        for depth in range(self.n_nodes):
-            self.mutable_ops.append(nn.ModuleList())
-            for i in range(2 + depth):  # include 2 input nodes
-                # reduction should be used only for input node
-                stride = 2 if reduction and i < 2 else 1
-                op = nas.mutables.LayerChoice(
-                    [ops.PoolBN('max', channels, 3, stride, 1, affine=False),
-                     ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
-                     ops.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
-                     ops.SepConv(channels, channels, 3, stride, 1, affine=False),
-                     ops.SepConv(channels, channels, 5, stride, 2, affine=False),
-                     ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
-                     ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
-                     ops.Zero(stride)],
-                    key="r{}_d{}_i{}".format(reduction, depth, i))
-                self.mutable_ops[depth].append(op)
-
-    def forward(self, s0, s1):
-        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
-        tensors = [self.preproc0(s0), self.preproc1(s1)]
-        for ops in self.mutable_ops:
-            assert len(ops) == len(tensors)
-            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
-            tensors.append(cur_tensor)
-
-        output = torch.cat(tensors[2:], dim=1)
-        return output
+from .cnn_cell import CnnCell


-class SearchCNN(nn.Module):
+class CnnNetwork(nn.Module):
     """
     Search CNN model
     """

-    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
+    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3,
+                 cell_type=CnnCell):
         """
         Initializing a search CNN.

@@ -121,7 +53,7 @@ class SearchCNN(nn.Module):
             c_cur *= 2
             reduction = True

-        cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
+        cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
         self.cells.append(cell)
         c_cur_out = c_cur * n_nodes
         channels_pp, channels_p = channels_p, c_cur_out
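
With the cell class injectable, assembling a search network looks like this (illustrative sketch mirroring the updated search.py and main.py, not part of the commit):

from nni.nas.pytorch.darts import CnnNetwork, CnnCell

# CIFAR-10 search space: 3 input channels, 16 initial channels, 10 classes,
# 5 cells with 4 nodes each; cell_type lets P-DARTS reuse this network class.
model = CnnNetwork(3, 16, 10, 5, n_nodes=4, cell_type=CnnCell)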
examples/nas/darts/ops.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py

 import torch
 import torch.nn as nn

 PRIMITIVES = [
-    'none',
     'max_pool_3x3',
     'avg_pool_3x3',
     'skip_connect',  # identity
@@ -10,15 +10,13 @@ PRIMITIVES = [
     'sep_conv_5x5',
     'dil_conv_3x3',
     'dil_conv_5x5',
+    'none'
 ]

 OPS = {
     'none': lambda C, stride, affine: Zero(stride),
     'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
     'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
-    'skip_connect': lambda C, stride, affine: \
-        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
+    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
     'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
     'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
     'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),

@@ -60,6 +58,7 @@ class PoolBN(nn.Module):
     """
     AvgPool or MaxPool - BN
     """
+
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
         """
         Args:

@@ -85,6 +84,7 @@ class StdConv(nn.Module):
     """ Standard conv
     ReLU - Conv - BN
     """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(

@@ -101,6 +101,7 @@ class FacConv(nn.Module):
     """ Factorized conv
     ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
+
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(

@@ -120,12 +121,12 @@ class DilConv(nn.Module):
     If dilation == 2, 3x3 conv => 5x5 receptive field
                       5x5 conv => 9x9 receptive field
     """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
         super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
             nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
             nn.BatchNorm2d(C_out, affine=affine)
         )

@@ -138,6 +139,7 @@ class SepConv(nn.Module):
     """ Depthwise separable conv
     DilConv(dilation=1) * 2
     """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(

@@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise(stride=2).
     """
+
     def __init__(self, C_in, C_out, affine=True):
         super().__init__()
         self.relu = nn.ReLU()
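
Each OPS entry is a factory keyed by primitive name and called as (channels, stride, affine) (illustrative sketch, not part of the commit):

import torch
from nni.nas.pytorch.darts.cnn_ops import OPS

op = OPS['sep_conv_3x3'](16, 1, False)  # 16 channels, stride 1, BN without affine
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)  # torch.Size([2, 16, 32, 32]); stride 1 preserves spatial size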
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

@@ -94,6 +94,7 @@ class DartsTrainer(Trainer):
         with torch.no_grad():
             for step, (X, y) in enumerate(self.valid_loader):
                 X, y = X.to(self.device), y.to(self.device)
-                logits = self.model(X)
+                with self.mutator.forward_pass():
+                    logits = self.model(X)
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
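
The fix wraps the validation forward in the mutator's forward_pass() context, matching the training path, so one set of architecture decisions presumably stays scoped to the pass. A generic sketch of that pattern (illustrative only, not nni's actual implementation):

from contextlib import contextmanager

class MutatorSketch:
    @contextmanager
    def forward_pass(self):
        self._cache = {}        # fresh architecture decisions for this pass
        try:
            yield self
        finally:
            self._cache = None  # invalidate once the pass is over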
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

@@ -40,7 +40,7 @@ class EnasMutator(PyTorchMutator):
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
+        self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
         self.cross_entropy_loss = nn.CrossEntropyLoss()

     def after_build(self, model):
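
The one-line change swaps torch.tensor for the legacy torch.Tensor constructor; with these float inputs both produce a float32 tensor (illustrative check, not part of the commit):

import torch

skip_target = 0.4
a = torch.tensor([1.0 - skip_target, skip_target])  # dtype inferred from the data
b = torch.Tensor([1.0 - skip_target, skip_target])  # legacy FloatTensor constructor
assert a.dtype == b.dtype == torch.float32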
src/sdk/pynni/nni/nas/pytorch/modules.py (new file, mode 100644)

from torch import nn as nn


class RankedModule(nn.Module):
    def __init__(self, rank=None, reduction=False):
        super(RankedModule, self).__init__()
        self.rank = rank
        self.reduction = reduction
src/sdk/pynni/nni/nas/pytorch/mutables.py

@@ -56,9 +56,6 @@ class PyTorchMutable(nn.Module):
                            "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
                            "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

-    def __repr__(self):
-        return "{} ({})".format(self.name, self.key)
-

 class MutableScope(PyTorchMutable):
     """

@@ -85,6 +82,9 @@ class LayerChoice(PyTorchMutable):
         self.reduction = reduction
         self.return_mask = return_mask

+    def __len__(self):
+        return self.length
+
     def forward(self, *inputs):
         out, mask = self.mutator.on_forward(self, *inputs)
         if self.return_mask:
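
The new __len__ lets callers write len(layer_choice) instead of reaching into .length; the PdartsMutator below shrinks that length as candidate ops are pruned. A minimal illustration of the contract (a stand-in class, not the real LayerChoice):

class LayerChoiceSketch:
    def __init__(self, n_candidates):
        self.length = n_candidates

    def __len__(self):
        return self.length

lc = LayerChoiceSketch(8)
lc.length -= 3   # e.g. after a P-DARTS pruning stage
print(len(lc))   # 5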
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py (new file, mode 100644)

from .trainer import PdartsTrainer
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (new file, mode 100644)

import copy

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.switches = switches

        super(PdartsMutator, self).__init__(model)

    def before_build(self, model):
        self.choices = nn.ParameterDict()
        if self.switches is None:
            self.switches = {}

    def named_mutables(self, model):
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, LayerChoice):
                key2module[module.key] = module
                yield name, module, True

    def drop_paths(self):
        # Disable the weakest candidates on every edge, as measured by the
        # softmax over the architecture weights.
        for key in self.switches:
            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
            switches = self.switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
                # for the last stage, drop all Zero operations
                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            else:
                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            for idx in drop:
                switches[idxs[idx]] = False
        return self.switches

    def on_init_layer_choice(self, mutable: LayerChoice):
        # Remove candidates that earlier stages switched off, then create a
        # fresh architecture parameter of the remaining length.
        switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
        for index in range(len(switches) - 1, -1, -1):
            if switches[index] == False:
                del(mutable.choices[index])
                mutable.length -= 1
        self.switches[mutable.key] = switches

        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)

    def get_min_k(self, input_in, k):
        index = []
        for _ in range(k):
            idx = np.argmin(input_in)
            index.append(idx)
        return index

    def get_min_k_no_zero(self, w_in, idxs, k):
        w = copy.deepcopy(w_in)
        index = []
        if 0 in idxs:
            zf = True
        else:
            zf = False
        if zf:
            w = w[1:]
            index.append(0)
            k = k - 1
        for _ in range(k):
            idx = np.argmin(w)
            w[idx] = 1
            if zf:
                idx = idx + 1
            index.append(idx)
        return index
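
To make the pruning selection concrete, here is a condensed standalone copy of get_min_k_no_zero with a worked example (illustrative, not part of the commit):

import copy
import numpy as np

def get_min_k_no_zero(w_in, idxs, k):
    w = copy.deepcopy(w_in)
    index = []
    if 0 in idxs:          # the Zero op is still active: drop it unconditionally
        w = w[1:]
        index.append(0)
        k = k - 1
    for _ in range(k):
        idx = int(np.argmin(w))
        w[idx] = 1         # mask the chosen entry so the next argmin moves on
        if 0 in idxs:
            idx = idx + 1  # shift back, since w[0] was sliced off
        index.append(idx)
    return index

weights = np.array([0.05, 0.30, 0.10, 0.25, 0.30])
print(get_min_k_no_zero(weights, idxs=[0, 1, 2, 3, 4], k=2))  # [0, 2]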
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (new file, mode 100644)

from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import Trainer

from .mutator import PdartsMutator


class PdartsTrainer(Trainer):

    def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid,
                 layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
        self.model_creator = model_creator
        self.layers = layers
        self.n_nodes = n_nodes
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency
        }

    def train(self):
        layers = self.layers
        n_nodes = self.n_nodes
        switches = None
        for epoch in range(self.pdarts_epoch):
            # Each stage deepens the network and rebuilds the model from scratch;
            # the mutator carries the surviving candidate ops (switches) forward.
            layers = self.layers + self.pdarts_num_layers[epoch]
            model, loss, model_optim, lr_scheduler = self.model_creator(layers, n_nodes)
            mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
                                        lr_scheduler=lr_scheduler, mutator=mutator,
                                        **self.darts_parameters)
            print("start pdarts training %s..." % epoch)

            self.trainer.train()

            # with open('log/parameters_%d.txt' % epoch, "w") as f:
            #     f.write(str(model.parameters))

            switches = mutator.drop_paths()

    def export(self):
        if (self.trainer is not None) and hasattr(self.trainer, "export"):
            self.trainer.export()
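
The progressive schedule is easiest to see numerically: with layers=5, pdarts_num_layers=[0, 6, 12] and pdarts_num_to_drop=[3, 2, 2] (the defaults above), each stage deepens the network while shrinking every edge's candidate pool (illustrative, not part of the commit; assumes the standard eight DARTS primitives):

layers, num_layers, num_to_drop = 5, [0, 6, 12], [3, 2, 2]
candidates = 8  # len(PRIMITIVES) in cnn_ops.py
for epoch in range(len(num_to_drop)):
    print("stage %d: %d cells, %d candidate ops per edge"
          % (epoch, layers + num_layers[epoch], candidates))
    candidates -= num_to_drop[epoch]
# stage 0: 5 cells, 8 candidate ops per edge
# stage 1: 11 cells, 5 candidate ops per edge
# stage 2: 17 cells, 3 candidate ops per edge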