OpenDAS / nni · Commits · 3ddab980

Commit 3ddab980, authored Nov 18, 2019 by quzha

    Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

Parents: 594924a9, d1d10de7

Showing 17 changed files with 421 additions and 25 deletions (+421 −25)
.gitignore                                           +1  −0
examples/nas/.gitignore                              +1  −1
examples/nas/darts/search.py                         +5  −7
examples/nas/pdarts/.gitignore                       +2  −0
examples/nas/pdarts/datasets.py                      +25 −0
examples/nas/pdarts/main.py                          +65 −0
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py      +2  −0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py      +69 −0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py   +73 −0
src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py       +13 −10
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py       +2  −1
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py        +2  −2
src/sdk/pynni/nni/nas/pytorch/modules.py             +9  −0
src/sdk/pynni/nni/nas/pytorch/mutables.py            +4  −4
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py     +1  −0
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py      +93 −0
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py      +54 −0
.gitignore

@@ -80,6 +80,7 @@ venv.bak/
 # VSCode
 .vscode
+.vs

 # In case you place source code in ~/nni/
 /experiments
examples/nas/.gitignore

 data
examples/nas/darts/search.py

 from argparse import ArgumentParser

-import datasets
 import torch
 import torch.nn as nn
-from model import SearchCNN
-from nni.nas.pytorch.darts import DartsTrainer
+
+import datasets
+from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
 from utils import accuracy

 if __name__ == "__main__":
     parser = ArgumentParser("darts")
-    parser.add_argument("--layers", default=4, type=int)
-    parser.add_argument("--nodes", default=2, type=int)
+    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--nodes", default=4, type=int)
     parser.add_argument("--batch-size", default=128, type=int)
     parser.add_argument("--log-frequency", default=1, type=int)
     args = parser.parse_args()

     dataset_train, dataset_valid = datasets.get_dataset("cifar10")
-    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
+    model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
     criterion = nn.CrossEntropyLoss()
     optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
 ...
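The tail of search.py is collapsed above (the trailing "..."). For orientation only, here is a hedged sketch of how the script plausibly finishes; the DartsTrainer keyword arguments are copied from the darts_parameters dict that pdarts/trainer.py (later in this commit) forwards to DartsTrainer, and the real tail may differ:

# Sketch, not the verbatim remainder of search.py.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50, eta_min=0.001)
trainer = DartsTrainer(model,
                       loss=criterion,
                       model_optim=optim,
                       lr_scheduler=lr_scheduler,
                       metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                       num_epochs=50,
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       batch_size=args.batch_size,
                       log_frequency=args.log_frequency)
trainer.train()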
examples/nas/pdarts/.gitignore (new file, mode 100644)

data/*
log
examples/nas/pdarts/datasets.py (new file, mode 100644)

from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
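As a quick usage check, get_dataset returns plain torchvision dataset objects that can be fed straight into a DataLoader; a minimal sketch (batch size and worker count are arbitrary here):

from torch.utils.data import DataLoader

import datasets  # the module above

dataset_train, dataset_valid = datasets.get_dataset("cifar10")
# CIFAR10 train split: 50000 images; valid split: 10000 images
train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True, num_workers=4)
for X, y in train_loader:
    # X: float tensor of shape (128, 3, 32, 32), normalized with MEAN/STD above
    # y: int64 tensor of shape (128,) with class labels 0..9
    break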
examples/nas/pdarts/main.py (new file, mode 100644)

from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    parser.add_argument('--add_layers', action='append', default=[0, 6, 12], help='add layers')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        n_epochs = 50
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=50,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()
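Note that the accuracy helper returns a dict keyed "acc{k}" rather than a bare float, which is the shape the trainer's metrics callback consumes. A small self-contained check, run alongside the accuracy function above (values illustrative):

import torch

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1]])   # two samples, three classes
target = torch.tensor([0, 2])              # second prediction is wrong

print(accuracy(logits, target, topk=(1,)))  # {'acc1': 0.5}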
src/sdk/pynni/nni/nas/pytorch/darts/__init__.py

 from .mutator import DartsMutator
 from .trainer import DartsTrainer
+from .cnn_cell import CnnCell
+from .cnn_network import CnnNetwork
src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py (new file, mode 100644)

import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule

from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv


class CnnCell(RankedModule):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialization a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is reduction cell or not.
        reduction: bool
            Flag for whether the current cell is reduction cell or not.
        """
        super(CnnCell, self).__init__(rank=1, reduction=reduction)
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                m_ops = []
                for primitive in PRIMITIVES:
                    op = OPS[primitive](channels, stride, False)
                    m_ops.append(op)
                op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
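Two sizes are worth calling out: the cell builds one LayerChoice per (depth, input) edge, and its output concatenates the n_nodes intermediate tensors on the channel axis, which is why the cnn_network.py diff below computes c_cur_out = c_cur * n_nodes. A tiny pure-Python check of that bookkeeping (no nni needed):

# Edge/channel bookkeeping for one cell; mirrors the loops above.
n_nodes, channels = 4, 16
n_edges = sum(2 + depth for depth in range(n_nodes))  # 2+3+4+5 = 14 LayerChoice edges
out_channels = n_nodes * channels                     # torch.cat(tensors[2:], dim=1) -> 64
print(n_edges, out_channels)  # 14 64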
examples/nas/darts/model.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py

-import torch
-import torch.nn as nn
-
-import ops
-from nni.nas import pytorch as nas
-
-
-class SearchCell(nn.Module):
-    """
-    Cell for search.
-    """
-    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
-        """
-        Initialization a search cell.
-
-        Parameters
-        ----------
-        n_nodes: int
-            Number of nodes in current DAG.
-        channels_pp: int
-            Number of output channels from previous previous cell.
-        channels_p: int
-            Number of output channels from previous cell.
-        channels: int
-            Number of channels that will be used in the current DAG.
-        reduction_p: bool
-            Flag for whether the previous cell is reduction cell or not.
-        reduction: bool
-            Flag for whether the current cell is reduction cell or not.
-        """
-        super().__init__()
-        self.reduction = reduction
-        self.n_nodes = n_nodes
-
-        # If previous cell is reduction cell, current input size does not match with
-        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
-        if reduction_p:
-            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
-        else:
-            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
-        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
-
-        # generate dag
-        self.mutable_ops = nn.ModuleList()
-        for depth in range(self.n_nodes):
-            self.mutable_ops.append(nn.ModuleList())
-            for i in range(2 + depth):  # include 2 input nodes
-                # reduction should be used only for input node
-                stride = 2 if reduction and i < 2 else 1
-                op = nas.mutables.LayerChoice(
-                    [ops.PoolBN('max', channels, 3, stride, 1, affine=False),
-                     ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
-                     ops.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
-                     ops.SepConv(channels, channels, 3, stride, 1, affine=False),
-                     ops.SepConv(channels, channels, 5, stride, 2, affine=False),
-                     ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
-                     ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
-                     ops.Zero(stride)],
-                    key="r{}_d{}_i{}".format(reduction, depth, i))
-                self.mutable_ops[depth].append(op)
-
-    def forward(self, s0, s1):
-        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
-        tensors = [self.preproc0(s0), self.preproc1(s1)]
-        for ops in self.mutable_ops:
-            assert len(ops) == len(tensors)
-            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
-            tensors.append(cur_tensor)
-
-        output = torch.cat(tensors[2:], dim=1)
-        return output
+import torch.nn as nn
+
+from .cnn_cell import CnnCell


-class SearchCNN(nn.Module):
+class CnnNetwork(nn.Module):
     """
     Search CNN model
     """
-    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
+    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell):
         """
         Initializing a search channelsNN.
 ...

@@ -121,7 +53,7 @@ class SearchCNN(nn.Module):
                 c_cur *= 2
                 reduction = True

-            cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
+            cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
             self.cells.append(cell)
             c_cur_out = c_cur * n_nodes
             channels_pp, channels_p = channels_p, c_cur_out
 ...
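The new cell_type parameter turns the cell class into an injection point: CnnNetwork defaults to CnnCell, but a caller can pass any class with the same constructor signature. This is exactly how examples/nas/pdarts/main.py above wires it up; shown explicitly:

from nni.nas.pytorch.darts import CnnNetwork, CnnCell

# identical to the default, spelled out as in the P-DARTS example
model = CnnNetwork(3, 16, 10, 8, n_nodes=4, cell_type=CnnCell)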
examples/nas/darts/ops.py → src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py

 import torch
 import torch.nn as nn

 PRIMITIVES = [
+    'none',
     'max_pool_3x3',
     'avg_pool_3x3',
     'skip_connect',  # identity
     'sep_conv_3x3',
     'sep_conv_5x5',
     'dil_conv_3x3',
     'dil_conv_5x5',
-    'none'
 ]

 OPS = {
     'none': lambda C, stride, affine: Zero(stride),
     'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
     'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
-    'skip_connect': lambda C, stride, affine: \
+    'skip_connect': lambda C, stride, affine:
         Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
     'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
     'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
     'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
     'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
     'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
     'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
 }
 ...

@@ -60,6 +58,7 @@ class PoolBN(nn.Module):
     """
     AvgPool or MaxPool - BN
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
         """
         Args:
 ...

@@ -85,6 +84,7 @@ class StdConv(nn.Module):
     """ Standard conv
     ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -101,6 +101,7 @@ class FacConv(nn.Module):
     """ Factorized conv
     ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -118,14 +119,14 @@ class DilConv(nn.Module):
     """ (Dilated) depthwise separable conv
     ReLU - (Dilated) depthwise separable - Pointwise - BN
     If dilation == 2, 3x3 conv => 5x5 receptive field
                       5x5 conv => 9x9 receptive field
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
         super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
                       bias=False),
             nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
             nn.BatchNorm2d(C_out, affine=affine)
         )
 ...

@@ -138,6 +139,7 @@ class SepConv(nn.Module):
     """ Depthwise separable conv
     DilConv(dilation=1) * 2
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
         self.net = nn.Sequential(
 ...

@@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise(stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
         super().__init__()
         self.relu = nn.ReLU()
 ...
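PRIMITIVES and OPS act as a name-indexed op factory: each entry maps a primitive name to a lambda that builds the module for a given channel count and stride. CnnCell above consumes them exactly this way; a minimal standalone sketch, assuming the cnn_ops module from this tree is importable:

from nni.nas.pytorch.darts.cnn_ops import OPS, PRIMITIVES

channels, stride = 16, 1
candidate_ops = [OPS[name](channels, stride, False) for name in PRIMITIVES]
# 8 candidate modules per edge: none, max_pool_3x3, avg_pool_3x3, skip_connect,
# sep_conv_3x3, sep_conv_5x5, dil_conv_3x3, dil_conv_5x5
print(len(candidate_ops))  # 8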
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

@@ -94,7 +94,8 @@ class DartsTrainer(Trainer):
         with torch.no_grad():
             for step, (X, y) in enumerate(self.valid_loader):
                 X, y = X.to(self.device), y.to(self.device)
-                logits = self.model(X)
+                with self.mutator.forward_pass():
+                    logits = self.model(X)
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
                 if self.log_frequency is not None and step % self.log_frequency == 0:
 ...
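The behavioral change here: each validation batch is now evaluated inside mutator.forward_pass(), giving the mutator a chance to produce fresh architecture decisions per pass instead of reusing whatever the last training step cached. Schematically, such a context manager resets cached decisions on entry; this is a hypothetical simplification of the pattern, not the actual NNI implementation:

from contextlib import contextmanager

class SketchMutator:
    # hypothetical stand-in illustrating the forward_pass pattern only
    def __init__(self):
        self._cache = {}

    @contextmanager
    def forward_pass(self):
        self._cache.clear()  # fresh architecture decisions for this pass
        try:
            yield self
        finally:
            self._cache.clear()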
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

@@ -40,7 +40,7 @@ class EnasMutator(PyTorchMutator):
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
+        self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
         self.cross_entropy_loss = nn.CrossEntropyLoss()

     def after_build(self, model):
 ...

@@ -79,7 +79,7 @@ class EnasMutator(PyTorchMutator):
         self._lstm_next_step()
         logit = self.soft(self._h[-1])
         if self.tanh_constant is not None:
             logit = self.tanh_constant * torch.tanh(logit)
         branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
         log_prob = self.cross_entropy_loss(logit, branch_id)
         self.sample_log_prob += log_prob
 ...
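For reference, the two constructors in the first hunk differ in dtype handling: torch.Tensor(...) always builds a tensor of the default float dtype (float32 unless changed), while torch.tensor(...) infers dtype from its argument. For the float inputs used here the result is the same, so the rename leaves runtime behavior unchanged in this case:

import torch

skip_target = 0.4
a = torch.tensor([1.0 - skip_target, skip_target])  # dtype inferred from data
b = torch.Tensor([1.0 - skip_target, skip_target])  # default float dtype
print(a.dtype, b.dtype)  # torch.float32 torch.float32 for these inputs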
src/sdk/pynni/nni/nas/pytorch/modules.py (new file, mode 100644)

from torch import nn as nn


class RankedModule(nn.Module):

    def __init__(self, rank=None, reduction=False):
        super(RankedModule, self).__init__()
        self.rank = rank
        self.reduction = reduction
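RankedModule is a plain marker base class: it records a rank and a reduction flag that NAS code can query when assembling cells. CnnCell above subclasses it via super(CnnCell, self).__init__(rank=1, reduction=reduction); a minimal sketch of the same pattern:

import torch.nn as nn

from nni.nas.pytorch.modules import RankedModule

class MyCell(RankedModule):
    # illustrative subclass following the CnnCell pattern above
    def __init__(self, reduction):
        super(MyCell, self).__init__(rank=1, reduction=reduction)
        self.body = nn.Identity()

cell = MyCell(reduction=False)
print(cell.rank, cell.reduction)  # 1 False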
src/sdk/pynni/nni/nas/pytorch/mutables.py

@@ -56,9 +56,6 @@ class PyTorchMutable(nn.Module):
                 "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__ "
                 "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

-    def __repr__(self):
-        return "{} ({})".format(self.name, self.key)
-

 class MutableScope(PyTorchMutable):
     """
 ...

@@ -85,6 +82,9 @@ class LayerChoice(PyTorchMutable):
         self.reduction = reduction
         self.return_mask = return_mask

+    def __len__(self):
+        return self.length
+
     def forward(self, *inputs):
         out, mask = self.mutator.on_forward(self, *inputs)
         if self.return_mask:
 ...

@@ -116,4 +116,4 @@ class InputChoice(PyTorchMutable):
     def similar(self, other):
         return type(self) == type(other) and \
             self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
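The added __len__ lets callers ask a LayerChoice how many candidate ops it currently holds with plain len(...); the PdartsMutator below leans on the same mutable.length counter when it prunes candidates stage by stage. A small sketch, assuming this tree's nni package is importable:

import torch.nn as nn
import nni.nas.pytorch as nas

ops = [nn.Identity() for _ in range(8)]           # stand-ins for the 8 primitives
choice = nas.mutables.LayerChoice(ops, key="demo")
print(len(choice))  # 8, via the new __len__; shrinks as candidates are dropped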
src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py (new file, mode 100644)

from .trainer import PdartsTrainer
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (new file, mode 100644)

import copy

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.switches = switches
        super(PdartsMutator, self).__init__(model)

    def before_build(self, model):
        self.choices = nn.ParameterDict()
        if self.switches is None:
            self.switches = {}

    def named_mutables(self, model):
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, LayerChoice):
                key2module[module.key] = module
                yield name, module, True

    def drop_paths(self):
        for key in self.switches:
            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
            switches = self.switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
                # for the last stage, drop all Zero operations
                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            else:
                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            for idx in drop:
                switches[idxs[idx]] = False
        return self.switches

    def on_init_layer_choice(self, mutable: LayerChoice):
        switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
        for index in range(len(switches) - 1, -1, -1):
            if switches[index] == False:
                del(mutable.choices[index])
                mutable.length -= 1
        self.switches[mutable.key] = switches
        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)

    def get_min_k(self, input_in, k):
        index = []
        for _ in range(k):
            idx = np.argmin(input_in)
            index.append(idx)
        return index

    def get_min_k_no_zero(self, w_in, idxs, k):
        w = copy.deepcopy(w_in)
        index = []
        if 0 in idxs:
            zf = True
        else:
            zf = False
        if zf:
            w = w[1:]
            index.append(0)
            k = k - 1
        for _ in range(k):
            idx = np.argmin(w)
            w[idx] = 1
            if zf:
                idx = idx + 1
            index.append(idx)
        return index
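drop_paths operates on a per-LayerChoice boolean mask: switches[key][j] says whether candidate j is still alive, and each stage flips the entries with the lowest softmax weight to False. A pure-Python rehearsal of one pruning step (illustrative numbers; np.argsort is used here for brevity where the file above loops over np.argmin):

import numpy as np

switches = [True] * 8                                       # all 8 candidate ops alive
prob = np.array([.02, .20, .05, .25, .08, .15, .10, .15])   # softmaxed alphas
num_to_drop = 3                                             # first P-DARTS stage

alive = [j for j, s in enumerate(switches) if s]
drop = np.argsort(prob)[:num_to_drop]                       # weakest candidates: 0, 2, 4
for idx in drop:
    switches[alive[idx]] = False
print(switches)  # [False, True, False, True, False, True, True, True]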
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (new file, mode 100644)

from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import Trainer

from .mutator import PdartsMutator


class PdartsTrainer(Trainer):

    def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid,
                 layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
        self.model_creator = model_creator
        self.layers = layers
        self.n_nodes = n_nodes
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency
        }

    def train(self):
        layers = self.layers
        n_nodes = self.n_nodes
        switches = None
        for epoch in range(self.pdarts_epoch):
            layers = self.layers + self.pdarts_num_layers[epoch]
            model, loss, model_optim, lr_scheduler = self.model_creator(layers, n_nodes)
            mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            self.trainer = DartsTrainer(model,
                                        loss=loss,
                                        model_optim=model_optim,
                                        lr_scheduler=lr_scheduler,
                                        mutator=mutator,
                                        **self.darts_parameters)
            print("start pdarts training %s..." % epoch)

            self.trainer.train()

            # with open('log/parameters_%d.txt' % epoch, "w") as f:
            #     f.write(str(model.parameters))

            switches = mutator.drop_paths()

    def export(self):
        if (self.trainer is not None) and hasattr(self.trainer, "export"):
            self.trainer.export()
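The progressive schedule falls out of the two lists: with layers=5, pdarts_num_layers=[0, 6, 12] and pdarts_num_to_drop=[3, 2, 2], each stage trains a deeper network over a thinner candidate set, ending with one op per edge (the 8 candidates come from PRIMITIVES in cnn_ops.py):

layers, num_layers, num_to_drop = 5, [0, 6, 12], [3, 2, 2]
candidates = 8  # len(PRIMITIVES)
for epoch in range(len(num_to_drop)):
    print("stage %d: %d layers, %d candidate ops per edge"
          % (epoch, layers + num_layers[epoch], candidates))
    candidates -= num_to_drop[epoch]
print("final candidates per edge:", candidates)  # 1
# stage 0: 5 layers, 8 candidate ops per edge
# stage 1: 11 layers, 5 candidate ops per edge
# stage 2: 17 layers, 3 candidate ops per edge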