GitLab commit 3ddab980 (OpenDAS/nni), authored Nov 18, 2019 by quzha

Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

Parents: 594924a9, d1d10de7
Showing 17 changed files with 421 additions and 25 deletions.
.gitignore (+1, -0)
examples/nas/.gitignore (+1, -1)
examples/nas/darts/search.py (+5, -7)
examples/nas/pdarts/.gitignore (+2, -0)
examples/nas/pdarts/datasets.py (+25, -0)
examples/nas/pdarts/main.py (+65, -0)
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py (+2, -0)
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py (+69, -0)
src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py (+73, -0)
src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py (+13, -10)
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py (+2, -1)
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py (+2, -2)
src/sdk/pynni/nni/nas/pytorch/modules.py (+9, -0)
src/sdk/pynni/nni/nas/pytorch/mutables.py (+4, -4)
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py (+1, -0)
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (+93, -0)
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (+54, -0)
.gitignore

```diff
@@ -80,6 +80,7 @@ venv.bak/
 # VSCode
 .vscode
 .vs
 # In case you place source code in ~/nni/
+/experiments
```
examples/nas/.gitignore

```diff
-data
+data
```
examples/nas/darts/search.py

```diff
 from argparse import ArgumentParser

-import datasets
 import torch
 import torch.nn as nn
-from model import SearchCNN
-from nni.nas.pytorch.darts import DartsTrainer
+
+import datasets
+from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
 from utils import accuracy

 if __name__ == "__main__":
     parser = ArgumentParser("darts")
-    parser.add_argument("--layers", default=4, type=int)
-    parser.add_argument("--nodes", default=2, type=int)
+    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--nodes", default=4, type=int)
     parser.add_argument("--batch-size", default=128, type=int)
     parser.add_argument("--log-frequency", default=1, type=int)
     args = parser.parse_args()

     dataset_train, dataset_valid = datasets.get_dataset("cifar10")
-    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
+    model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
     criterion = nn.CrossEntropyLoss()

     optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
 ...
```
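The rest of search.py is collapsed in the diff view. Judging only from the keyword arguments that DartsTrainer receives in the P-DARTS trainer further below, the elided tail plausibly wires these pieces into a trainer along the following lines; this is a sketch under that assumption, not the committed code:

```python
trainer = DartsTrainer(model,
                       loss=criterion,
                       model_optim=optim,
                       metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                       num_epochs=50,
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       batch_size=args.batch_size,
                       log_frequency=args.log_frequency)
trainer.train()
```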
examples/nas/pdarts/.gitignore (new file)

```
data/*
log
```
examples/nas/pdarts/datasets.py (new file)

```python
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)
    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
```
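As a quick sanity check, get_dataset can be exercised directly; the first call downloads CIFAR-10 into ./data. A minimal sketch, assuming this file is importable as `datasets`:

```python
from datasets import get_dataset  # assumes examples/nas/pdarts is on the path

dataset_train, dataset_valid = get_dataset("cifar10")
print(len(dataset_train), len(dataset_valid))  # 50000 10000 -- the standard CIFAR-10 splits
```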
examples/nas/pdarts/main.py (new file)

```python
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    parser.add_argument('--add_layers', action='append', default=[0, 6, 12], help='add layers')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        n_epochs = 50
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=50,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()
```
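The accuracy helper returns a dict keyed "acc{k}". A toy check of it on a hand-built batch (my own example; the import path assumes the file above is importable as `main`):

```python
import torch
from main import accuracy  # hypothetical import of examples/nas/pdarts/main.py

output = torch.tensor([[0.9, 0.1],   # argmax 0, target 0 -> hit
                       [0.2, 0.8],   # argmax 1, target 1 -> hit
                       [0.7, 0.3]])  # argmax 0, target 1 -> miss
target = torch.tensor([0, 1, 1])
print(accuracy(output, target, topk=(1,)))  # {'acc1': 0.666...} -- 2 of 3 correct
```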
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py

```diff
 from .mutator import DartsMutator
 from .trainer import DartsTrainer
+from .cnn_cell import CnnCell
+from .cnn_network import CnnNetwork
```
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py (new file)

```python
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule

from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv


class CnnCell(RankedModule):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialize a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is reduction cell or not.
        reduction: bool
            Flag for whether the current cell is reduction cell or not.
        """
        super(CnnCell, self).__init__(rank=1, reduction=reduction)
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                m_ops = []
                for primitive in PRIMITIVES:
                    op = OPS[primitive](channels, stride, False)
                    m_ops.append(op)
                op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
```
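Construction alone is enough to see the DAG shape; a forward pass additionally requires a mutator to be attached to the LayerChoice mutables, so this sketch (my illustration, not part of the commit) only builds a cell:

```python
# a normal (non-reduction) cell with 4 nodes; both input cells carry 16 channels
cell = CnnCell(n_nodes=4, channels_pp=16, channels_p=16, channels=16,
               reduction_p=False, reduction=False)
print(len(cell.mutable_ops))     # 4 -- one op list per intermediate node
print(len(cell.mutable_ops[0]))  # 2 -- node 0 connects to the 2 input nodes
print(len(cell.mutable_ops[3]))  # 5 -- node 3 connects to 2 inputs + 3 earlier nodes
```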
examples/nas/darts/model.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py

```diff
-import torch
 import torch.nn as nn

-import ops
-from nni.nas import pytorch as nas
-
-
-class SearchCell(nn.Module):
-    """
-    Cell for search.
-    """
-
-    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
-        """
-        Initialization a search cell.
-
-        Parameters
-        ----------
-        n_nodes: int
-            Number of nodes in current DAG.
-        channels_pp: int
-            Number of output channels from previous previous cell.
-        channels_p: int
-            Number of output channels from previous cell.
-        channels: int
-            Number of channels that will be used in the current DAG.
-        reduction_p: bool
-            Flag for whether the previous cell is reduction cell or not.
-        reduction: bool
-            Flag for whether the current cell is reduction cell or not.
-        """
-        super().__init__()
-        self.reduction = reduction
-        self.n_nodes = n_nodes
-
-        # If previous cell is reduction cell, current input size does not match with
-        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
-        if reduction_p:
-            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
-        else:
-            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
-        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
-
-        # generate dag
-        self.mutable_ops = nn.ModuleList()
-        for depth in range(self.n_nodes):
-            self.mutable_ops.append(nn.ModuleList())
-            for i in range(2 + depth):  # include 2 input nodes
-                # reduction should be used only for input node
-                stride = 2 if reduction and i < 2 else 1
-                op = nas.mutables.LayerChoice(
-                    [ops.PoolBN('max', channels, 3, stride, 1, affine=False),
-                     ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
-                     ops.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
-                     ops.SepConv(channels, channels, 3, stride, 1, affine=False),
-                     ops.SepConv(channels, channels, 5, stride, 2, affine=False),
-                     ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
-                     ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
-                     ops.Zero(stride)],
-                    key="r{}_d{}_i{}".format(reduction, depth, i))
-                self.mutable_ops[depth].append(op)
-
-    def forward(self, s0, s1):
-        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
-        tensors = [self.preproc0(s0), self.preproc1(s1)]
-        for ops in self.mutable_ops:
-            assert len(ops) == len(tensors)
-            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
-            tensors.append(cur_tensor)
-
-        output = torch.cat(tensors[2:], dim=1)
-        return output
+from .cnn_cell import CnnCell


-class SearchCNN(nn.Module):
+class CnnNetwork(nn.Module):
     """
     Search CNN model
     """

-    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
+    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4,
+                 stem_multiplier=3, cell_type=CnnCell):
         """
         Initializing a search CNN.
 ...

@@ -121,7 +53,7 @@ class SearchCNN(nn.Module):
             c_cur *= 2
             reduction = True

-        cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
+        cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
         self.cells.append(cell)
         c_cur_out = c_cur * n_nodes
         channels_pp, channels_p = channels_p, c_cur_out
 ...
```
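Injecting the cell class through the new cell_type parameter (defaulting to CnnCell) is what lets the P-DARTS code reuse this network unchanged: model_creator in examples/nas/pdarts/main.py above simply passes cell_type=CnnCell, and a different cell implementation could later be swapped in without touching CnnNetwork itself.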
examples/nas/darts/ops.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py

```diff
 import torch
 import torch.nn as nn

 PRIMITIVES = [
-    'none',
     'max_pool_3x3',
     'avg_pool_3x3',
-    'skip_connect',  # identity
+    'skip_connect',  # identity
     'sep_conv_3x3',
     'sep_conv_5x5',
     'dil_conv_3x3',
     'dil_conv_5x5',
+    'none'
 ]

 OPS = {
     'none': lambda C, stride, affine: Zero(stride),
     'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
     'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
-    'skip_connect': lambda C, stride, affine: \
-        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
+    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
     'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
     'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
     'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
-    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
-    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
+    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
+    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
     'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
 }
 ...

@@ -60,6 +58,7 @@ class PoolBN(nn.Module):
     """
     AvgPool or MaxPool - BN
     """
+
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
         """
         Args:
 ...

@@ -85,6 +84,7 @@ class StdConv(nn.Module):
     """ Standard conv
     ReLU - Conv - BN
     """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -101,6 +101,7 @@ class FacConv(nn.Module):
     """ Factorized conv
     ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
+
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -118,14 +119,14 @@ class DilConv(nn.Module):
     """ (Dilated) depthwise separable conv
     ReLU - (Dilated) depthwise separable - Pointwise - BN
     If dilation == 2, 3x3 conv => 5x5 receptive field
-                      5x5 conv => 9x9 receptive field
+                      5x5 conv => 9x9 receptive field
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
         super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
-            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
+            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
             nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
             nn.BatchNorm2d(C_out, affine=affine)
         )
 ...

@@ -138,6 +139,7 @@ class SepConv(nn.Module):
     """ Depthwise separable conv
     DilConv(dilation=1) * 2
     """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise(stride=2).
     """
+
     def __init__(self, C_in, C_out, affine=True):
         super().__init__()
         self.relu = nn.ReLU()
 ...
```
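Every OPS entry is a factory with signature (C, stride, affine), which is exactly how CnnCell builds its candidate list. A standalone check, assuming the module is importable at its new path:

```python
import torch
from nni.nas.pytorch.darts.cnn_ops import OPS

op = OPS['sep_conv_3x3'](16, 1, False)   # SepConv(16, 16, kernel 3, stride 1, padding 1)
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)  # torch.Size([2, 16, 32, 32]) -- stride 1 preserves the spatial size
```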
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

```diff
@@ -94,7 +94,8 @@ class DartsTrainer(Trainer):
         with torch.no_grad():
             for step, (X, y) in enumerate(self.valid_loader):
                 X, y = X.to(self.device), y.to(self.device)
-                logits = self.model(X)
+                with self.mutator.forward_pass():
+                    logits = self.model(X)
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
                 if self.log_frequency is not None and step % self.log_frequency == 0:
 ...
```
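The change above wraps each validation forward in self.mutator.forward_pass(), presumably a context manager that scopes one architecture sample to that call, so validation batches no longer reuse whatever decisions the mutator last made during training.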
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

```diff
@@ -40,7 +40,7 @@ class EnasMutator(PyTorchMutator):
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
+        self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
         self.cross_entropy_loss = nn.CrossEntropyLoss()

     def after_build(self, model):
 ...

@@ -79,7 +79,7 @@ class EnasMutator(PyTorchMutator):
         self._lstm_next_step()
         logit = self.soft(self._h[-1])
         if self.tanh_constant is not None:
-            logit = self.tanh_constant * torch.tanh(logit)
+            logit = self.tanh_constant * torch.tanh(logit)
         branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
         log_prob = self.cross_entropy_loss(logit, branch_id)
         self.sample_log_prob += log_prob
 ...
```
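Two small fixes here: torch.tensor([...]) infers its dtype from the given values, while torch.Tensor([...]) always constructs a float32 tensor, so the skip_targets parameter now has a guaranteed dtype. The second hunk appears to be a whitespace-only re-indent, since the text of the line is unchanged.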
src/sdk/pynni/nni/nas/pytorch/modules.py (new file)

```python
from torch import nn as nn


class RankedModule(nn.Module):

    def __init__(self, rank=None, reduction=False):
        super(RankedModule, self).__init__()
        self.rank = rank
        self.reduction = reduction
```
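RankedModule only records metadata for the NAS machinery; subclasses forward rank and reduction up, the way CnnCell above calls super().__init__(rank=1, reduction=reduction). A minimal illustration with a hypothetical subclass of my own:

```python
from torch import nn
from nni.nas.pytorch.modules import RankedModule


class ToyCell(RankedModule):
    def __init__(self, reduction):
        super(ToyCell, self).__init__(rank=1, reduction=reduction)
        self.body = nn.ReLU()  # placeholder computation


cell = ToyCell(reduction=True)
print(cell.rank, cell.reduction)  # 1 True
```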
src/sdk/pynni/nni/nas/pytorch/mutables.py

```diff
@@ -56,9 +56,6 @@ class PyTorchMutable(nn.Module):
                 "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__ "
                 "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

-    def __repr__(self):
-        return "{} ({})".format(self.name, self.key)
-

 class MutableScope(PyTorchMutable):
     """
 ...

@@ -85,6 +82,9 @@ class LayerChoice(PyTorchMutable):
         self.reduction = reduction
         self.return_mask = return_mask

+    def __len__(self):
+        return self.length
+
     def forward(self, *inputs):
         out, mask = self.mutator.on_forward(self, *inputs)
         if self.return_mask:
 ...

@@ -116,4 +116,4 @@ class InputChoice(PyTorchMutable):
     def similar(self, other):
         return type(self) == type(other) and \
-            self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
+               self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
```
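With the new __len__, len(layer_choice) now mirrors layer_choice.length; note that PdartsMutator below still reads and decrements the length attribute directly when it deletes dropped candidates.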
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py (new file)

```python
from .trainer import PdartsTrainer
```
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (new file)

```python
import copy

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.switches = switches
        super(PdartsMutator, self).__init__(model)

    def before_build(self, model):
        self.choices = nn.ParameterDict()
        if self.switches is None:
            self.switches = {}

    def named_mutables(self, model):
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, LayerChoice):
                key2module[module.key] = module
                yield name, module, True

    def drop_paths(self):
        for key in self.switches:
            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
            switches = self.switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
                # for the last stage, drop all Zero operations
                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            else:
                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            for idx in drop:
                switches[idxs[idx]] = False
        return self.switches

    def on_init_layer_choice(self, mutable: LayerChoice):
        switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
        # delete candidates that earlier stages switched off, back to front
        for index in range(len(switches) - 1, -1, -1):
            if switches[index] == False:
                del(mutable.choices[index])
                mutable.length -= 1
        self.switches[mutable.key] = switches

        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)

    def get_min_k(self, input_in, k):
        index = []
        for _ in range(k):
            idx = np.argmin(input_in)  # index of the smallest architecture weight
            index.append(idx)
        return index

    def get_min_k_no_zero(self, w_in, idxs, k):
        w = copy.deepcopy(w_in)
        index = []
        if 0 in idxs:
            zf = True
        else:
            zf = False
        if zf:
            w = w[1:]
            index.append(0)
            k = k - 1
        for _ in range(k):
            idx = np.argmin(w)
            w[idx] = 1
            if zf:
                idx = idx + 1
            index.append(idx)
        return index
```
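To make the last-stage pruning concrete, here is a standalone restatement of get_min_k_no_zero (my own illustration, mirroring the method above): when the Zero op (index 0) is still among the active candidates it is always dropped, and the remaining k-1 drops go to the smallest architecture weights.

```python
import numpy as np


def get_min_k_no_zero(w_in, idxs, k):
    w = w_in.copy()
    index = []
    zf = 0 in idxs            # is the Zero op still active?
    if zf:
        w = w[1:]             # take Zero out of the running...
        index.append(0)       # ...and mark it as dropped unconditionally
        k -= 1
    for _ in range(k):
        idx = int(np.argmin(w))
        w[idx] = 1            # mask out so the next argmin finds the next-smallest
        index.append(idx + 1 if zf else idx)
    return index


print(get_min_k_no_zero(np.array([0.05, 0.5, 0.1, 0.35]), [0, 1, 2, 3], 2))
# [0, 2]: the Zero op, plus the weakest remaining op (weight 0.1)
```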
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (new file)

```python
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import Trainer

from .mutator import PdartsMutator


class PdartsTrainer(Trainer):

    def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid,
                 layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
        self.model_creator = model_creator
        self.layers = layers
        self.n_nodes = n_nodes
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency
        }

    def train(self):
        layers = self.layers
        n_nodes = self.n_nodes
        switches = None
        for epoch in range(self.pdarts_epoch):
            layers = self.layers + self.pdarts_num_layers[epoch]
            model, loss, model_optim, lr_scheduler = self.model_creator(layers, n_nodes)
            mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            self.trainer = DartsTrainer(model,
                                        loss=loss,
                                        model_optim=model_optim,
                                        lr_scheduler=lr_scheduler,
                                        mutator=mutator,
                                        **self.darts_parameters)
            print("start pdarts training %s..." % epoch)

            self.trainer.train()

            # with open('log/parameters_%d.txt' % epoch, "w") as f:
            #     f.write(str(model.parameters))

            switches = mutator.drop_paths()

    def export(self):
        if (self.trainer is not None) and hasattr(self.trainer, "export"):
            self.trainer.export()
```
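End-to-end usage is shown in examples/nas/pdarts/main.py above: construct PdartsTrainer from a model_creator closure, call train() to run the progressive stages (each stage builds a deeper model via pdarts_num_layers, trains it with DartsTrainer, then prunes candidate ops with drop_paths before the next stage), and finally call export().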