OpenDAS / nni · Commit 73b2221b

Update DARTS trainer and fix docstring issues (#1772)

Authored Nov 22, 2019 by Yuge Zhang; committed by QuanluZhang on Nov 22, 2019.
Parent: 6d6f9524

Showing 20 changed files with 236 additions and 144 deletions (+236 −144).
Changed files:

- examples/nas/.gitignore (+1 −0)
- examples/nas/darts/model.py (+2 −1)
- examples/nas/darts/ops.py (+20 −19)
- examples/nas/darts/retrain.py (+16 −6)
- examples/nas/darts/search.py (+4 −3)
- examples/nas/enas/macro.py (+1 −1)
- examples/nas/enas/search.py (+2 −2)
- src/sdk/pynni/nni/nas/pytorch/base_mutator.py (+12 −15)
- src/sdk/pynni/nni/nas/pytorch/callbacks.py (+1 −1)
- src/sdk/pynni/nni/nas/pytorch/darts/mutator.py (+2 −2)
- src/sdk/pynni/nni/nas/pytorch/darts/trainer.py (+82 −55)
- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py (+2 −1)
- src/sdk/pynni/nni/nas/pytorch/enas/trainer.py (+4 −6)
- src/sdk/pynni/nni/nas/pytorch/fixed.py (+7 −6)
- src/sdk/pynni/nni/nas/pytorch/mutables.py (+14 −10)
- src/sdk/pynni/nni/nas/pytorch/mutator.py (+15 −9)
- src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (+1 −1)
- src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (+2 −2)
- src/sdk/pynni/nni/nas/pytorch/trainer.py (+34 −0)
- src/sdk/pynni/nni/nas/pytorch/utils.py (+14 −4)
examples/nas/.gitignore

```diff
 data
 checkpoints
+runs
```
examples/nas/darts/model.py

```diff
@@ -48,7 +48,7 @@ class Node(nn.Module):
             ops.SepConv(channels, channels, 3, stride, 1, affine=False),
             ops.SepConv(channels, channels, 5, stride, 2, affine=False),
             ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
-            ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
+            ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
         ], key=choice_keys[-1]))
         self.drop_path = ops.DropPath_()
@@ -57,6 +57,7 @@ class Node(nn.Module):
     def forward(self, prev_nodes):
         assert len(self.ops) == len(prev_nodes)
         out = [op(node) for op, node in zip(self.ops, prev_nodes)]
+        out = [self.drop_path(o) if o is not None else None for o in out]
         return self.input_switch(out)
```
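The second hunk threads every intermediate output through `self.drop_path` before the input switch, skipping `None` entries. `ops.DropPath_` itself lives in `examples/nas/darts/ops.py`; for readers unfamiliar with the op, here is a minimal out-of-place sketch of what drop path does (the example's module is in-place, and the details below are an illustration, not code from this commit):

```python
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Zero an entire sample's branch with probability p during training
    and rescale the survivors, so the expected activation is unchanged."""
    def __init__(self, p=0.):
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep = 1. - self.p
            # one Bernoulli draw per sample, broadcast over channels and spatial dims
            mask = torch.zeros(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep)
            x = x / keep * mask
        return x
```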
examples/nas/darts/ops.py

```diff
@@ -4,9 +4,13 @@ import torch.nn as nn
 class DropPath_(nn.Module):
     def __init__(self, p=0.):
-        """ [!] DropPath is inplace module
-        Args:
-            p: probability of an path to be zeroed.
+        """
+        DropPath is an inplace module.
+
+        Parameters
+        ----------
+        p : float
+            Probability of a path to be zeroed.
         """
         super().__init__()
         self.p = p
@@ -26,13 +30,9 @@ class DropPath_(nn.Module):
 class PoolBN(nn.Module):
     """
-    AvgPool or MaxPool - BN
+    AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
-        """
-        Args:
-            pool_type: 'max' or 'avg'
-        """
         super().__init__()
         if pool_type.lower() == 'max':
             self.pool = nn.MaxPool2d(kernel_size, stride, padding)
@@ -50,8 +50,8 @@ class PoolBN(nn.Module):
 class StdConv(nn.Module):
     """
-    Standard conv
+    Standard conv:
     ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
@@ -66,8 +66,8 @@ class StdConv(nn.Module):
 class FacConv(nn.Module):
     """
-    Factorized conv
+    Factorized conv:
     ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
         super().__init__()
@@ -83,10 +83,10 @@ class FacConv(nn.Module):
 class DilConv(nn.Module):
     """
-    (Dilated) depthwise separable conv
-    ReLU - (Dilated) depthwise separable - Pointwise - BN
-    If dilation == 2, 3x3 conv => 5x5 receptive field
-    5x5 conv => 9x9 receptive field
+    (Dilated) depthwise separable conv.
+    ReLU - (Dilated) depthwise separable - Pointwise - BN.
+    If dilation == 2, 3x3 conv => 5x5 receptive field,
+    5x5 conv => 9x9 receptive field.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
         super().__init__()
@@ -103,8 +103,9 @@ class DilConv(nn.Module):
 class SepConv(nn.Module):
-    """ Depthwise separable conv
-    DilConv(dilation=1) * 2
+    """
+    Depthwise separable conv.
+    DilConv(dilation=1) * 2.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super().__init__()
@@ -119,7 +120,7 @@ class SepConv(nn.Module):
 class FactorizedReduce(nn.Module):
     """
-    Reduce feature map size by factorized pointwise(stride=2).
+    Reduce feature map size by factorized pointwise (stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
         super().__init__()
```
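The receptive-field claims in the DilConv docstring follow from the usual dilation arithmetic: a k×k kernel with dilation d spans k + (k − 1)(d − 1) input positions per axis. A quick check of the two cases quoted above:

```python
def effective_span(k, d):
    """Input positions covered along one axis by a k x k kernel with dilation d."""
    return k + (k - 1) * (d - 1)

assert effective_span(3, 2) == 5  # "3x3 conv => 5x5 receptive field"
assert effective_span(5, 2) == 9  # "5x5 conv => 9x9 receptive field"
```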
examples/nas/darts/retrain.py

```diff
@@ -4,12 +4,13 @@ from argparse import ArgumentParser
 import torch
 import torch.nn as nn
+from nni.nas.pytorch.fixed import apply_fixed_architecture
+from nni.nas.pytorch.utils import AverageMeter
+from torch.utils.tensorboard import SummaryWriter

 import datasets
 import utils
 from model import CNN
-from nni.nas.pytorch.fixed import apply_fixed_architecture
-from nni.nas.pytorch.utils import AverageMeter

 logger = logging.getLogger()
@@ -23,6 +24,7 @@ logger.setLevel(logging.INFO)
 logger.addHandler(std_out_info)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+writer = SummaryWriter()

 def train(config, train_loader, model, optimizer, criterion, epoch):
@@ -33,6 +35,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
     cur_step = epoch * len(train_loader)
     cur_lr = optimizer.param_groups[0]['lr']
     logger.info("Epoch %d LR %.6f", epoch, cur_lr)
+    writer.add_scalar("lr", cur_lr, global_step=cur_step)

     model.train()
@@ -54,6 +57,9 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
         losses.update(loss.item(), bs)
         top1.update(accuracy["acc1"], bs)
         top5.update(accuracy["acc5"], bs)
+        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
+        writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
+        writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)

         if step % config.log_frequency == 0 or step == len(train_loader) - 1:
             logger.info(
@@ -77,15 +83,15 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
     with torch.no_grad():
         for step, (X, y) in enumerate(valid_loader):
             X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
-            N = X.size(0)
+            bs = X.size(0)

             logits = model(X)
             loss = criterion(logits, y)
             accuracy = utils.accuracy(logits, y, topk=(1, 5))
-            losses.update(loss.item(), N)
-            top1.update(accuracy["acc1"], N)
-            top5.update(accuracy["acc5"], N)
+            losses.update(loss.item(), bs)
+            top1.update(accuracy["acc1"], bs)
+            top5.update(accuracy["acc5"], bs)

             if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
                 logger.info(
@@ -94,6 +100,10 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
                     epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
                     top1=top1, top5=top5))

+    writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
+    writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
+    writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)
+
     logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))

     return top1.avg
```
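The retrain script now mirrors its meters into TensorBoard. A zero-argument `SummaryWriter()` writes event files under `./runs` by default, which is also why `runs` joins `examples/nas/.gitignore` in this commit. A minimal standalone sketch of the logging pattern (synthetic values, not from the script):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # defaults to ./runs/<timestamp>_<hostname>
for step in range(100):
    writer.add_scalar("loss/train", 1.0 / (step + 1), global_step=step)
writer.close()
# Inspect the curves with: tensorboard --logdir runs
```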
examples/nas/darts/search.py

```diff
@@ -7,8 +7,7 @@ import torch.nn as nn
 import datasets
 from model import CNN
-from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
-                                       LearningRateScheduler)
+from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
 from nni.nas.pytorch.darts import DartsTrainer
 from utils import accuracy
@@ -29,6 +28,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch-size", default=64, type=int)
     parser.add_argument("--log-frequency", default=10, type=int)
     parser.add_argument("--epochs", default=50, type=int)
+    parser.add_argument("--unrolled", default=False, action="store_true")
     args = parser.parse_args()

     dataset_train, dataset_valid = datasets.get_dataset("cifar10")
@@ -48,5 +48,6 @@ if __name__ == "__main__":
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
-                           callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
+                           unrolled=args.unrolled,
+                           callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
     trainer.train()
```
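`action="store_true"` makes `--unrolled` a simple presence flag: omitted, `args.unrolled` stays `False` and the trainer takes the cheaper first-order architecture step; passed, it enables the second-order unrolled step (see the `DartsTrainer` diff below). A sketch of the flag's behavior:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--unrolled", default=False, action="store_true")
print(parser.parse_args([]).unrolled)              # False -> first-order DARTS
print(parser.parse_args(["--unrolled"]).unrolled)  # True  -> unrolled (second-order)
```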
examples/nas/enas/macro.py

```diff
@@ -19,7 +19,7 @@ class ENASLayer(mutables.MutableScope):
             PoolBranch('max', in_filters, out_filters, 3, 1, 1)
         ])
         if len(prev_labels) > 0:
-            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
+            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
         else:
             self.skipconnect = None
         self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
```
examples/nas/enas/search.py

```diff
@@ -9,7 +9,7 @@ import datasets
 from macro import GeneralNetwork
 from micro import MicroNetwork
 from nni.nas.pytorch import enas
-from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
+from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint
 from utils import accuracy, reward_accuracy

 logger = logging.getLogger()
@@ -51,7 +51,7 @@ if __name__ == "__main__":
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optimizer,
-                              callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
+                              callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                               batch_size=args.batch_size,
                               num_epochs=num_epochs,
                               dataset_train=dataset_train,
```
src/sdk/pynni/nni/nas/pytorch/base_mutator.py

```diff
@@ -51,21 +51,22 @@ class BaseMutator(nn.Module):
     def mutables(self):
         return self._structured_mutables

-    @property
     def forward(self, *inputs):
         raise RuntimeError("Forward is undefined for mutators.")

+    def __setattr__(self, name, value):
+        if name == "model":
+            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
+                                 "include you network, as it will include all parameters in model into the mutator.")
+        return super().__setattr__(name, value)
+
     def enter_mutable_scope(self, mutable_scope):
         """
         Callback when forward of a MutableScope is entered.

         Parameters
         ----------
-        mutable_scope: MutableScope
-
-        Returns
-        -------
-        None
+        mutable_scope : MutableScope
         """
         pass
@@ -75,11 +76,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
-        mutable_scope: MutableScope
-
-        Returns
-        -------
-        None
+        mutable_scope : MutableScope
         """
         pass
@@ -89,8 +86,8 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
-        mutable: LayerChoice
-        inputs: list of torch.Tensor
+        mutable : LayerChoice
+        inputs : list of torch.Tensor

         Returns
         -------
@@ -105,8 +102,8 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
-        mutable: InputChoice
-        tensor_list: list of torch.Tensor
+        mutable : InputChoice
+        tensor_list : list of torch.Tensor

         Returns
         -------
```
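The new `__setattr__` guard exists because `nn.Module.__setattr__` registers any module assigned as an attribute, so `self.model = model` inside a mutator would fold every model weight into `mutator.parameters()` — and DARTS hands exactly those parameters to the architecture optimizer. A minimal sketch of the failure mode and of the `__dict__` escape hatch this codebase uses elsewhere (class names are illustrative):

```python
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

class Registers(nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner  # nn.Module.__setattr__ registers it as a submodule

class Bypasses(nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.__dict__["inner"] = inner  # plain attribute, kept out of the registry

inner = Inner()
print(len(list(Registers(inner).parameters())))  # 2 -- weight and bias leak in
print(len(list(Bypasses(inner).parameters())))   # 0
```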
src/sdk/pynni/nni/nas/pytorch/callbacks.py

```diff
@@ -29,7 +29,7 @@ class Callback:
         pass


-class LearningRateScheduler(Callback):
+class LRSchedulerCallback(Callback):
     def __init__(self, scheduler, mode="epoch"):
         super().__init__()
         assert mode == "epoch"
```
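The rename makes the intent explicit: the class is a trainer `Callback` that wraps an existing `torch.optim.lr_scheduler` instance and, with `mode="epoch"`, steps it once per epoch. A sketch of how the renamed class is wired up, mirroring the example scripts above (the cosine schedule here is an assumption; any scheduler object works):

```python
import torch
import torch.nn as nn
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint

model = nn.Linear(4, 4)  # stand-in for the search model
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50)
callbacks = [LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]
# `callbacks` is then passed to DartsTrainer / enas.EnasTrainer as shown above.
```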
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py

```diff
 import torch
-from torch import nn as nn
-from torch.nn import functional as F
+import torch.nn as nn
+import torch.nn.functional as F

 from nni.nas.pytorch.mutator import Mutator
 from nni.nas.pytorch.mutables import LayerChoice, InputChoice
```
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

```diff
@@ -2,27 +2,27 @@ import copy
 import logging

 import torch
-from torch import nn as nn
+import torch.nn as nn

 from nni.nas.pytorch.trainer import Trainer
 from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import DartsMutator

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)


 class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
-        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
+        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
+        self.unrolled = unrolled
         n_train = len(self.dataset_train)
         split = n_train // 2
         indices = list(range(n_train))
@@ -43,42 +43,32 @@ class DartsTrainer(Trainer):
     def train_one_epoch(self, epoch):
         self.model.train()
         self.mutator.train()
-        lr = self.optimizer.param_groups[0]["lr"]
         meters = AverageMeterGroup()
         for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
             trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
             val_X, val_y = val_X.to(self.device), val_y.to(self.device)

-            # backup model for hessian
-            # cannot deepcopy model because it will break the reference
-            backup_model = copy.deepcopy(self.model.state_dict())
+            # phase 1. architecture step
+            self.ctrl_optim.zero_grad()
+            if self.unrolled:
+                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
+            else:
+                self._backward(val_X, val_y)
+            self.ctrl_optim.step()

-            # phase 1. child network step
+            # phase 2: child network step
             self.optimizer.zero_grad()
-            self.mutator.reset()
-            logits = self.model(trn_X)
-            loss = self.loss(logits, trn_y)
+            logits, loss = self._logits_and_loss(trn_X, trn_y)
             loss.backward()
             # gradient clipping
             nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
             self.optimizer.step()
-            new_model = copy.deepcopy(self.model.state_dict())
-
-            # phase 2. architect step (alpha)
-            self.ctrl_optim.zero_grad()
-            # compute unrolled loss
-            self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr)
-            self.ctrl_optim.step()
-            self.model.load_state_dict(new_model)

             metrics = self.metrics(logits, trn_y)
             metrics["loss"] = loss.item()
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
                 logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                             self.num_epochs, step + 1, len(self.train_loader), meters)

     def validate_one_epoch(self, epoch):
         self.model.eval()
@@ -92,55 +82,92 @@ class DartsTrainer(Trainer):
                 metrics = self.metrics(logits, y)
                 meters.update(metrics)
                 if self.log_frequency is not None and step % self.log_frequency == 0:
                     logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                                 self.num_epochs, step + 1, len(self.test_loader), meters)

+    def _logits_and_loss(self, X, y):
+        self.mutator.reset()
+        logits = self.model(X)
+        loss = self.loss(logits, y)
+        return logits, loss
+
+    def _backward(self, val_X, val_y):
+        """
+        Simple backward with gradient descent
+        """
+        _, loss = self._logits_and_loss(val_X, val_y)
+        loss.backward()
+
-    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr):
+    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
         """
         Compute unrolled loss and backward its gradients
-        Parameters
-        ----------
-        v_model: backup model before this step
-        lr: learning rate for virtual gradient step (same as net lr)
         """
-        self.mutator.reset()
-        loss = self.loss(self.model(val_X), val_y)
-        w_model = tuple(self.model.parameters())
-        w_ctrl = tuple(self.mutator.parameters())
+        backup_params = copy.deepcopy(tuple(self.model.parameters()))
+
+        # do virtual step on training data
+        lr = self.optimizer.param_groups[0]["lr"]
+        momentum = self.optimizer.param_groups[0]["momentum"]
+        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
+        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
+
+        # calculate unrolled loss on validation data
+        # keep gradients for model here for compute hessian
+        _, loss = self._logits_and_loss(val_X, val_y)
+        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
         w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
-        d_model = w_grads[:len(w_model)]
-        d_ctrl = w_grads[len(w_model):]
-        hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y)
+        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
+
+        # compute hessian and final gradients
+        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
         with torch.no_grad():
             for param, d, h in zip(w_ctrl, d_ctrl, hessian):
-                # gradient = dalpha - lr * hessian
                 param.grad = d - lr * h

-    def _compute_hessian(self, model, dw, trn_X, trn_y):
+        # restore weights
+        self._restore_weights(backup_params)
+
+    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
+        """
+        Compute unrolled weights w`
+        """
+        # don't need zero_grad, using autograd to calculate gradients
+        _, loss = self._logits_and_loss(X, y)
+        gradients = torch.autograd.grad(loss, self.model.parameters())
+        with torch.no_grad():
+            for w, g in zip(self.model.parameters(), gradients):
+                m = self.optimizer.state[w].get("momentum_buffer", 0.)
+                w = w - lr * (momentum * m + g + weight_decay * w)
+
+    def _restore_weights(self, backup_params):
+        with torch.no_grad():
+            for param, backup in zip(self.model.parameters(), backup_params):
+                param.copy_(backup)
+
+    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
         """
         dw = dw` { L_val(w`, alpha) }
         w+ = w + eps * dw
         w- = w - eps * dw
         hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
         eps = 0.01 / ||dw||
         """
-        self.model.load_state_dict(model)
+        self._restore_weights(backup_params)
         norm = torch.cat([w.view(-1) for w in dw]).norm()
         eps = 0.01 / norm
+        if norm < 1E-8:
+            logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())

+        dalphas = []
         for e in [eps, -2. * eps]:
             # w+ = w + eps*dw`, w- = w - eps*dw`
             with torch.no_grad():
                 for p, d in zip(self.model.parameters(), dw):
-                    p += eps * d
+                    p += e * d
-            self.mutator.reset()
-            loss = self.loss(self.model(trn_X), trn_y)
-            if e > 0:
-                dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
-            elif e < 0:
-                dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w-) }
+            _, loss = self._logits_and_loss(trn_X, trn_y)
+            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))

+        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
         hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
         return hessian
```
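The restructure makes each training step: (1) an architecture step on validation data — either `_unrolled_backward` (second-order DARTS) or the new `_backward` (first-order) — followed by (2) an ordinary weight step on training data. `_compute_virtual_model` replaces the old state-dict backup/restore dance with one hand-written SGD-with-momentum step, and `_compute_hessian` keeps the finite-difference approximation from the DARTS paper. Note the repaired perturbation bookkeeping: the loop adds `e * d` with `e = eps` and then `e = -2. * eps`, so the weights visit w+ and then w−; the old `p += eps * d` moved the same direction twice and never reached w−. A tiny numeric check of the cumulative updates:

```python
import torch

w, dw, eps = torch.tensor([1.0]), torch.tensor([0.5]), 0.01
for e in [eps, -2. * eps]:
    w += e * dw
    print(w.item())  # 1.005 (= w + eps*dw), then 0.995 (= w - eps*dw)
```

One reading note on the final line, unchanged by this commit: `(p - n) / 2. * eps` parses as `((p - n) / 2) * eps`, whereas the docstring's formula divides by `2 * eps`; worth double-checking against the paper if the magnitude of this term matters to you.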
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

```diff
@@ -25,6 +25,7 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):
     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                  skip_target=0.4, branch_bias=0.25):
         super().__init__(model)
@@ -51,7 +52,7 @@ class EnasMutator(Mutator):
                 self.max_layer_choice = mutable.length
             assert self.max_layer_choice == mutable.length, \
                 "ENAS mutator requires all layer choice have the same number of candidates."
-            # NOTE(yuge): We might implement an interface later. Judging by key now.
+            # We are judging by keys and module types to add biases to layer choices. Needs refactor.
             if "reduce" in mutable.key:
                 def is_conv(choice):
                     return "conv" in str(type(choice)).lower()
```
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py

```diff
@@ -6,9 +6,7 @@ from nni.nas.pytorch.trainer import Trainer
 from nni.nas.pytorch.utils import AverageMeterGroup
 from .mutator import EnasMutator

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)


 class EnasTrainer(Trainer):
@@ -75,8 +73,8 @@ class EnasTrainer(Trainer):
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("Model Epoch [%s/%s] Step [%s/%s]  %s", epoch,
-                            self.num_epochs, step, len(self.train_loader), meters)
+                logger.info("Model Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
+                            self.num_epochs, step + 1, len(self.train_loader), meters)

         # Train sampler (mutator)
         self.model.eval()
@@ -114,8 +112,8 @@ class EnasTrainer(Trainer):
                 self.mutator_optim.zero_grad()
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("RL Epoch [%s/%s] Step [%s/%s]  %s", epoch, self.num_epochs,
-                            mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters)
+                logger.info("RL Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1, self.num_epochs,
+                            mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
             mutator_step += 1
             if mutator_step >= total_mutator_steps:
                 break
```
src/sdk/pynni/nni/nas/pytorch/fixed.py

```diff
@@ -14,11 +14,11 @@ class FixedArchitecture(Mutator):
         Parameters
         ----------
-        model: nn.Module
+        model : nn.Module
             A mutable network.
-        fixed_arc: str or dict
+        fixed_arc : str or dict
             Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
-        strict: bool
+        strict : bool
             Force everything that appears in `fixed_arc` to be used at least once.
         """
         super().__init__(model)
@@ -55,11 +55,11 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
         Parameters
         ----------
-        model: torch.nn.Module
+        model : torch.nn.Module
             Model with mutables.
-        fixed_arc_path: str
+        fixed_arc_path : str
             Path to the JSON that stores the architecture.
-        device: torch.device
+        device : torch.device
             Architecture weights will be transferred to `device`.

         Returns
@@ -76,3 +76,4 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
     architecture = FixedArchitecture(model, fixed_arc)
     architecture.to(device)
     architecture.reset()
+    return architecture
```
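With the new `return architecture`, callers can keep a handle on the `FixedArchitecture` mutator instead of discarding it. A hypothetical retraining preamble (the model factory and JSON path are placeholders, not from this commit):

```python
from nni.nas.pytorch.fixed import apply_fixed_architecture

model = build_model()  # hypothetical factory for a network containing mutables
architecture = apply_fixed_architecture(model, "checkpoints/arch.json")
# `model` now runs only the chosen ops; `architecture` is the FixedArchitecture
# mutator, available for inspection or device moves.
```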
src/sdk/pynni/nni/nas/pytorch/mutables.py

```diff
@@ -39,6 +39,9 @@ class Mutable(nn.Module):
         return super().__call__(*args, **kwargs)

     def set_mutator(self, mutator):
+        if "mutator" in self.__dict__:
+            raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
+                               "Or did you apply multiple fixed architectures?")
         self.__dict__["mutator"] = mutator

     def forward(self, *inputs):
@@ -68,9 +71,10 @@ class Mutable(nn.Module):
 class MutableScope(Mutable):
     """
-    Mutable scope labels a subgraph/submodule to help mutators make better decisions.
+    Mutable scope marks a subgraph/submodule to help mutators make better decisions.
     Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
     and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
+    Mutable scopes are also mutables that are listed in the mutables (search space).
     """

     def __init__(self, key):
@@ -86,7 +90,7 @@ class MutableScope(Mutable):
 class LayerChoice(Mutable):
-    def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
+    def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
         super().__init__(key=key)
         self.length = len(op_candidates)
         self.choices = nn.ModuleList(op_candidates)
@@ -117,25 +121,25 @@ class InputChoice(Mutable):
     NO_KEY = ""

     def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
-                 reduction="mean", return_mask=False, key=None):
+                 reduction="sum", return_mask=False, key=None):
         """
         Initialization.

         Parameters
         ----------
-        n_candidates: int
+        n_candidates : int
             Number of inputs to choose from.
-        choose_from: list of str
+        choose_from : list of str
             List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be fulfilled.
             If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
             number of empty strings.
-        n_chosen: int
+        n_chosen : int
             Recommended inputs to choose. If None, mutator is instructed to select any.
-        reduction: str
+        reduction : str
             `mean`, `concat`, `sum` or `none`.
-        return_mask: bool
+        return_mask : bool
             If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
-        key: str
+        key : str
             Key of the input choice.
         """
         super().__init__(key=key)
@@ -163,7 +167,7 @@ class InputChoice(Mutable):
         Parameters
         ----------
-        optional_inputs: list or dict
+        optional_inputs : list or dict
             Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
             `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
             `choose_from`.
```
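Both choice types now default to `reduction="sum"` rather than `"mean"`, which is why `examples/nas/enas/macro.py` above could drop its explicit `reduction="sum"` argument. The reduction decides how a mutator collapses the tensors of the selected candidates; a sketch of the semantics the docstring lists (`mean`, `concat`, `sum`, `none`):

```python
import torch

def reduce_tensors(reduction_type, tensor_list):
    """Sketch of how chosen candidate outputs are collapsed into one tensor."""
    if reduction_type == "none":
        return tensor_list
    if reduction_type == "sum":
        return sum(tensor_list)
    if reduction_type == "mean":
        return sum(tensor_list) / len(tensor_list)
    if reduction_type == "concat":
        return torch.cat(tensor_list, dim=1)  # channel-wise, by convention
    raise ValueError("Unrecognized reduction: {}".format(reduction_type))

outs = [torch.ones(1, 2), torch.ones(1, 2)]
print(reduce_tensors("sum", outs))   # tensor([[2., 2.]])
print(reduce_tensors("mean", outs))  # tensor([[1., 1.]])
```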
src/sdk/pynni/nni/nas/pytorch/mutator.py

```diff
+import logging
+
 import torch

 from nni.nas.pytorch.base_mutator import BaseMutator

+logger = logging.getLogger(__name__)
+

 class Mutator(BaseMutator):
@@ -60,8 +64,8 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable: LayerChoice
-        inputs: list of torch.Tensor
+        mutable : LayerChoice
+        inputs : list of torch.Tensor

         Returns
         -------
@@ -85,9 +89,9 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable: InputChoice
-        tensor_list: list of torch.Tensor
-        tags: list of string
+        mutable : InputChoice
+        tensor_list : list of torch.Tensor
+        tags : list of string

         Returns
         -------
@@ -108,7 +112,7 @@ class Mutator(BaseMutator):
         return out

     def _tensor_reduction(self, reduction_type, tensor_list):
-        if tensor_list == "none":
+        if reduction_type == "none":
             return tensor_list
         if not tensor_list:
             return None  # empty. return None for now
@@ -129,12 +133,14 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable: Mutable
+        mutable : Mutable

         Returns
         -------
-        any
+        object
         """
         if mutable.key not in self._cache:
             raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
-        return self._cache[mutable.key]
+        result = self._cache[mutable.key]
+        logger.debug("Decision %s: %s", mutable.key, result)
+        return result
```
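The `_tensor_reduction` fix is worth a second look: the old guard compared the list of tensors to the string `"none"`, which is never true, so the `"none"` policy silently fell through to the reduction code below. The repaired guard tests the policy name instead:

```python
tensor_list = [1.0, 2.0]
print(tensor_list == "none")      # False for any list: the old guard was dead code

reduction_type = "none"
print(reduction_type == "none")   # True: the fixed guard short-circuits correctly
```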
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py

```diff
@@ -4,7 +4,7 @@
 import copy

 import numpy as np
-from torch.nn import functional as F
+import torch.nn.functional as F

 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice
```
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py

```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import logging

-from nni.nas.pytorch.callbacks import LearningRateScheduler
+from nni.nas.pytorch.callbacks import LRSchedulerCallback
 from nni.nas.pytorch.darts import DartsTrainer
 from nni.nas.pytorch.trainer import BaseTrainer
@@ -50,7 +50,7 @@ class PdartsTrainer(BaseTrainer):
             darts_callbacks = []
             if lr_scheduler is not None:
-                darts_callbacks.append(LearningRateScheduler(lr_scheduler))
+                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

             self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                         callbacks=darts_callbacks, **self.darts_parameters)
```
src/sdk/pynni/nni/nas/pytorch/trainer.py

```diff
@@ -24,6 +24,40 @@ class TorchTensorEncoder(json.JSONEncoder):
 class Trainer(BaseTrainer):
     def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                  dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
+        """
+        Trainer initialization.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model with mutables.
+        mutator : BaseMutator
+            A mutator object that has been initialized with the model.
+        loss : callable
+            Called with logits and targets. Returns a loss tensor.
+        metrics : callable
+            Returns a dict that maps metrics keys to metrics data.
+        optimizer : Optimizer
+            Optimizer that optimizes the model.
+        num_epochs : int
+            Number of epochs of training.
+        dataset_train : torch.utils.data.Dataset
+            Dataset of training.
+        dataset_valid : torch.utils.data.Dataset
+            Dataset of validation/testing.
+        batch_size : int
+            Batch size.
+        workers : int
+            Number of workers used in data preprocessing.
+        device : torch.device
+            Device object. Either `torch.device("cuda")` or `torch.device("cpu")`. When `None`, trainer will
+            automatically detect and prefer GPU.
+        log_frequency : int
+            Number of mini-batches to log metrics.
+        callbacks : list of Callback
+            Callbacks to plug into the trainer. See Callbacks.
+        """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
         self.model = model
         self.mutator = mutator
```
src/sdk/pynni/nni/nas/pytorch/utils.py

```diff
@@ -28,6 +28,16 @@ class AverageMeter:
 """Computes and stores the average and current value"""

     def __init__(self, name, fmt=':f'):
+        """
+        Initialization of AverageMeter.
+
+        Parameters
+        ----------
+        name : str
+            Name to display.
+        fmt : str
+            Format string to print the values.
+        """
         self.name = name
         self.fmt = fmt
         self.reset()
@@ -78,12 +88,12 @@ class StructuredMutableTreeNode:
         Parameters
         ----------
-        order: str
+        order : str
             pre or post. If pre, current mutable is yielded before children. Otherwise after.
-        deduplicate: bool
+        deduplicate : bool
             If true, mutables with the same key will not appear after the first appearance.
-        memo: dict
+        memo : dict
-            An auxiliary variable to make deduplicate happen.
+            An auxiliary dict that memorizes keys seen before, so that deduplication is possible.

         Returns
         -------
```
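The newly documented `AverageMeter` follows the usual running-average convention: `update(value, n)` folds in a batch of size `n`, and `.avg` holds the size-weighted mean — exactly how the retrain script feeds it. A quick sketch (the positional `n` argument is inferred from the call sites above):

```python
from nni.nas.pytorch.utils import AverageMeter

meter = AverageMeter("acc1", ":.4f")
meter.update(0.50, 32)  # batch of 32 at 50% accuracy
meter.update(0.75, 32)  # batch of 32 at 75% accuracy
print(meter.avg)        # 0.625
```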