OpenDAS / nni / Commits

Commit a0fd0036 (unverified), authored Aug 01, 2022 by Yuge Zhang, committed by GitHub on Aug 01, 2022

    Merge pull request #5036 from microsoft/promote-retiarii-to-nas

    [DO NOT SQUASH] Promote retiarii to NAS

Parents: d6dcb483, bc6d8796
Changes: 239
Showing 20 changed files with 0 additions and 4070 deletions (+0 -4070):

    nni/algorithms/nas/pytorch/cdarts/mutator.py            +0  -143
    nni/algorithms/nas/pytorch/cdarts/trainer.py            +0  -275
    nni/algorithms/nas/pytorch/cdarts/utils.py              +0  -76
    nni/algorithms/nas/pytorch/classic_nas/mutator.py       +0  -221
    nni/algorithms/nas/pytorch/cream/trainer.py             +0  -403
    nni/algorithms/nas/pytorch/cream/utils.py               +0  -37
    nni/algorithms/nas/pytorch/darts/mutator.py             +0  -85
    nni/algorithms/nas/pytorch/darts/trainer.py             +0  -214
    nni/algorithms/nas/pytorch/enas/trainer.py              +0  -209
    nni/algorithms/nas/pytorch/fbnet/__init__.py            +0  -14
    nni/algorithms/nas/pytorch/fbnet/mutator.py             +0  -268
    nni/algorithms/nas/pytorch/fbnet/trainer.py             +0  -413
    nni/algorithms/nas/pytorch/fbnet/utils.py               +0  -433
    nni/algorithms/nas/pytorch/pdarts/mutator.py            +0  -93
    nni/algorithms/nas/pytorch/pdarts/trainer.py            +0  -86
    nni/algorithms/nas/pytorch/proxylessnas/__init__.py     +0  -5
    nni/algorithms/nas/pytorch/proxylessnas/mutator.py      +0  -478
    nni/algorithms/nas/pytorch/proxylessnas/trainer.py      +0  -500
    nni/algorithms/nas/pytorch/proxylessnas/utils.py        +0  -78
    nni/algorithms/nas/pytorch/random/mutator.py            +0  -39
nni/algorithms/nas/pytorch/cdarts/mutator.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from apex.parallel import DistributedDataParallel  # pylint: disable=import-error

from nni.algorithms.nas.pytorch.darts import DartsMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator  # pylint: disable=wrong-import-order


class RegularizedDartsMutator(DartsMutator):
    """
    This is basically :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator`, with two differences.

    1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in
       the forward pass and thus consume no memory.

    2. Regularization on choices, to prevent the mutator from overfitting on some choices.
    """

    def reset(self):
        """
        Warnings
        --------
        Renamed :func:`~reset_with_loss` to return regularization loss on reset.
        """
        raise ValueError("You should probably call `reset_with_loss`.")

    def cut_choices(self, cut_num=2):
        """
        Cut the choices with the smallest weights.
        ``cut_num`` should be the cumulative number of cuts, e.g., if the first cut
        is 2, the second time should be 4 to cut another two.

        Parameters
        ----------
        cut_num : int
            Number of choices to cut, so far.

        Warnings
        --------
        Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive a gradient of 0,
        which introduces a ``nan`` problem when calling ``optimizer.step()``. To solve this issue, a simple way is to
        reset the ``nan`` entries to :math:`-\infty` each time after the parameters are updated.
        """
        # `cut_choices` is implemented but not used in current implementation of CdartsTrainer
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                _, idx = torch.topk(-self.choices[mutable.key], cut_num)
                with torch.no_grad():
                    for i in idx:
                        self.choices[mutable.key][i] = -float("inf")

    def reset_with_loss(self):
        """
        Resample and return loss. If loss is 0, to avoid device issue, it will return ``None``.

        Currently the loss penalty is proportional to the L1-norm of parameters corresponding
        to modules if their type name contains certain substrings. These substrings include: ``poolwithoutbn``,
        ``identity``, ``dilconv``.
        """
        self._cache, reg_loss = self.sample_search()
        return reg_loss

    def sample_search(self):
        result = super().sample_search()
        loss = []
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                def need_reg(choice):
                    return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])

                for i, choice in enumerate(mutable.choices):
                    if need_reg(choice):
                        norm = torch.abs(self.choices[mutable.key][i])
                        if norm < 1E10:
                            loss.append(norm)
        if not loss:
            return result, None
        return result, sum(loss)

    def export(self, logger=None):
        """
        Export an architecture with logger. Genotype will be printed with logger.

        Returns
        -------
        dict
            A mapping from mutable keys to decisions.
        """
        result = self.sample_final()
        if hasattr(self.model, "plot_genotype") and logger is not None:
            genotypes = self.model.plot_genotype(result, logger)
        return result, genotypes


class RegularizedMutatorParallel(DistributedDataParallel):
    """
    Parallelize :class:`~RegularizedDartsMutator`.

    This makes the :func:`~RegularizedDartsMutator.reset_with_loss` method parallelized,
    also allowing :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
    to be easily accessible.
    """

    def reset_with_loss(self):
        """
        Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
        """
        result = self.module.reset_with_loss()
        self.callback_queued = False
        return result

    def cut_choices(self, *args, **kwargs):
        """
        Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
        """
        self.module.cut_choices(*args, **kwargs)

    def export(self, logger):
        """
        Parallelized :func:`~RegularizedDartsMutator.export`.
        """
        return self.module.export(logger)


class DartsDiscreteMutator(Mutator):
    """
    A mutator that applies the final sampling result of a parent mutator on another model to train.

    Parameters
    ----------
    model : nn.Module
        The model to apply the mutator.
    parent_mutator : nni.nas.pytorch.mutator.Mutator
        The mutator that provides the ``sample_final`` method, which will be called to get the architecture.
    """

    def __init__(self, model, parent_mutator):
        super().__init__(model)
        self.__dict__["parent_mutator"] = parent_mutator  # avoid parameters to be included

    def sample_search(self):
        return self.parent_mutator.sample_final()
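The ``nan`` warning in ``cut_choices`` above can be reproduced without NNI at all. Below is a minimal standalone sketch (plain PyTorch, not part of the deleted file; the names ``alpha`` and ``opt`` and the hyperparameters are illustrative): once an architecture weight is set to ``-inf``, an SGD step with weight decay turns that entry into ``nan``, and the "reset nan back to -inf" trick from the docstring restores it.

# Standalone sketch: why a cut (-inf) entry needs a nan reset after optimizer.step().
import torch

alpha = torch.nn.Parameter(1e-3 * torch.randn(4))      # architecture weights for one LayerChoice
opt = torch.optim.SGD([alpha], lr=0.1, momentum=0.9, weight_decay=3e-4)

with torch.no_grad():
    alpha[0] = float("-inf")                            # "cut" one choice, as cut_choices does

loss = (torch.softmax(alpha, dim=-1) * torch.arange(4.)).sum()
loss.backward()                                         # gradient at the -inf entry is 0
opt.step()                                              # weight decay multiplies -inf -> update is nan

print(alpha)                                            # alpha[0] is now nan
with torch.no_grad():
    alpha[alpha != alpha] = float("-inf")               # nan != nan, so this resets only nan entries
print(alpha)                                            # alpha[0] is -inf again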
nni/algorithms/nas/pytorch/cdarts/trainer.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import apex  # pylint: disable=import-error
from apex.parallel import DistributedDataParallel  # pylint: disable=import-error
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup  # pylint: disable=wrong-import-order

from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics

PHASE_SMALL = "small"
PHASE_LARGE = "large"


class InteractiveKLLoss(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        # self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
        self.kl_loss = nn.KLDivLoss()

    def forward(self, student, teacher):
        return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
                            F.softmax(teacher / self.temperature, dim=1))


class CdartsTrainer(object):
    """
    CDARTS trainer.

    Parameters
    ----------
    model_small : nn.Module
        PyTorch model to be trained. This is the search network of CDARTS.
    model_large : nn.Module
        PyTorch model to be trained. This is the evaluation network of CDARTS.
    criterion : callable
        Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
    loaders : list of torch.utils.data.DataLoader
        List of train data and valid data loaders, for training weights and architecture weights respectively.
    samplers : list of torch.utils.data.Sampler
        List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
        In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
    logger : logging.Logger
        The logger for logging. Will use nni logger by default (if logger is ``None``).
    regular_coeff : float
        The coefficient of regular loss.
    regular_ratio : float
        The ratio of regular loss.
    warmup_epochs : int
        The number of epochs to warm up the search network.
    fix_head : bool
        ``True`` if fixing the parameters of auxiliary heads, else unfix the parameters of auxiliary heads.
    epochs : int
        Number of epochs planned for training.
    steps_per_epoch : int
        Steps of one epoch.
    loss_alpha : float
        The loss coefficient.
    loss_T : float
        The loss coefficient (temperature of the interactive loss).
    distributed : bool
        ``True`` if using distributed training, else non-distributed training.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping for weights.
    interactive_type : string
        ``kl`` or ``smoothl1``.
    output_path : string
        Log storage path.
    w_lr : float
        Learning rate of the search network parameters.
    w_momentum : float
        Momentum of the search and the evaluation network.
    w_weight_decay : float
        The weight decay of the search and the evaluation network parameters.
    alpha_lr : float
        Learning rate of the architecture parameters.
    alpha_weight_decay : float
        The weight decay of the architecture parameters.
    nasnet_lr : float
        Learning rate of the evaluation network parameters.
    local_rank : int
        The index of the current process (local rank).
    share_module : bool
        ``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
    """
    def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                 regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                 epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
                 log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                 w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                 nasnet_lr=0.2, local_rank=0, share_module=True):
        if logger is None:
            logger = logging.getLogger(__name__)
        train_loader, valid_loader = loaders
        train_sampler, valid_sampler = samplers
        self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
        self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)

        self.regular_coeff = regular_coeff
        self.regular_ratio = regular_ratio
        self.warmup_epochs = warmup_epochs
        self.fix_head = fix_head
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        if self.steps_per_epoch is None:
            self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
        self.loss_alpha = loss_alpha
        self.grad_clip = grad_clip
        if interactive_type == "kl":
            self.interactive_loss = InteractiveKLLoss(loss_T)
        elif interactive_type == "smoothl1":
            self.interactive_loss = nn.SmoothL1Loss()
        self.loss_T = loss_T
        self.distributed = distributed
        self.log_frequency = log_frequency
        self.main_proc = not distributed or local_rank == 0

        self.logger = logger
        self.checkpoint_dir = output_path
        if self.main_proc:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        if distributed:
            torch.distributed.barrier()

        self.model_small = model_small
        self.model_large = model_large
        if self.fix_head:
            for param in self.model_small.aux_head.parameters():
                param.requires_grad = False
            for param in self.model_large.aux_head.parameters():
                param.requires_grad = False

        self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
        self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
        self.criterion = criterion

        self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
                                                betas=(0.5, 0.999), weight_decay=alpha_weight_decay)

        if distributed:
            apex.parallel.convert_syncbn_model(self.model_small)
            apex.parallel.convert_syncbn_model(self.model_large)
            self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
            self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
            self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
            if share_module:
                self.model_small.callback_queued = True
                self.model_large.callback_queued = True
            # mutator large never gets optimized, so do not need parallelized

    def _warmup(self, phase, epoch):
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            x, y = next(self.train_loader)
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            logits_main, _ = model(x)
            loss = self.criterion(logits_main, y)
            loss.backward()
            self._clip_grad_norm(model)
            optimizer.step()
            prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)

    def _clip_grad_norm(self, model):
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        with torch.no_grad():
            for param in parameters:
                for i, p in enumerate(param):
                    if p != p:  # equivalent to `isnan(p)`
                        param[i] = float("-inf")

    def _joint_train(self, epoch):
        self.model_large.train()
        self.model_small.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            trn_x, trn_y = next(self.train_loader)
            val_x, val_y = next(self.valid_loader)
            trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()

            # step 1. optimize architecture
            self.optimizer_alpha.zero_grad()
            self.optimizer_large.zero_grad()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            loss_regular = self.mutator_small.reset_with_loss()
            if loss_regular:
                loss_regular *= reg_decay
            logits_search, emsemble_logits_search = self.model_small(val_x)
            logits_main, emsemble_logits_main = self.model_large(val_x)
            loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
            loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
            loss = loss_cls + loss_interactive + loss_regular
            loss.backward()
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.optimizer_alpha.step()
            # NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`

            # step 2. optimize op weights
            self.optimizer_small.zero_grad()
            with torch.no_grad():
                # resample architecture since parameters have been changed
                self.mutator_small.reset_with_loss()
            logits_search_train, _ = self.model_small(trn_x)
            loss_weight = self.criterion(logits_search_train, trn_y)
            loss_weight.backward()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()

            metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
                       "loss_regular": loss_regular, "loss_weight": loss_weight}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():  # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)

            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))
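For readers following the loss terms in ``_joint_train``: the "interactive" term is a temperature-scaled KL divergence between the two networks' logits, rescaled by the squared temperature and ``loss_alpha``. The following standalone sketch (random tensors, illustrative values, not part of the deleted file) reproduces just that computation.

# Standalone sketch of the temperature-scaled KL term used as the "interactive" loss.
import torch
import torch.nn.functional as F

temperature, loss_alpha = 2.0, 2.0                      # same defaults as loss_T / loss_alpha above
student = torch.randn(8, 10)                            # logits of the search (small) network
teacher = torch.randn(8, 10)                            # logits of the evaluation (large) network

kl = F.kl_div(F.log_softmax(student / temperature, dim=1),
              F.softmax(teacher / temperature, dim=1),
              reduction='mean')                         # matches nn.KLDivLoss() defaults
loss_interactive = kl * temperature ** 2 * loss_alpha   # the scaling applied in _joint_train
print(loss_interactive)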
nni/algorithms/nas/pytorch/cdarts/utils.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os

import torch
import torch.distributed as dist


class CyclicIterator:
    def __init__(self, loader, sampler, distributed):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self.distributed = distributed
        self._next_epoch()

    def _next_epoch(self):
        if self.distributed:
            self.sampler.set_epoch(self.epoch)
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self._next_epoch()
            return next(self.iterator)


class TorchTensorEncoder(json.JSONEncoder):
    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            return o.tolist()
        return super().default(o)


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(os.environ["WORLD_SIZE"])
    return rt


def reduce_metrics(metrics, distributed=False):
    if distributed:
        return {k: reduce_tensor(v).item() for k, v in metrics.items()}
    return {k: v.item() for k, v in metrics.items()}
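As a concrete check of the precision@k logic in ``accuracy`` above, the sketch below walks the same steps on a tiny hand-made batch. Since the module itself is deleted, the computation is inlined rather than imported; the logits and labels are illustrative.

# Standalone walk-through of the precision@k logic on a 3-sample batch.
import torch

output = torch.tensor([[0.1, 0.7, 0.2],      # top-1 prediction: class 1
                       [0.8, 0.05, 0.15],    # top-1 prediction: class 0, top-2 includes class 2
                       [0.2, 0.3, 0.5]])     # top-1 prediction: class 2
target = torch.tensor([1, 2, 2])             # the second sample is wrong at top-1, right at top-2

maxk = 2
_, pred = output.topk(maxk, 1, True, True)   # (batch, k) indices of the largest logits
pred = pred.t()                              # (k, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))

for k in (1, 2):
    acc_k = correct[:k].reshape(-1).float().sum(0) / target.size(0)
    print(f"prec@{k} = {acc_k.item():.3f}")  # prec@1 = 0.667, prec@2 = 1.000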
nni/algorithms/nas/pytorch/classic_nas/mutator.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
import sys

import torch

import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.pytorch.mutator import Mutator

logger = logging.getLogger(__name__)

NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"


def get_and_apply_next_architecture(model):
    """
    Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
    similar to ``get_next_parameter`` for HPO.

    It will generate a search space based on ``model``.
    If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
    generating the search space for the experiment.
    If not, there are still two modes: one is nni experiment mode, where users
    use ``nnictl`` to start an experiment; the other is standalone mode,
    where users directly run the trial command. The latter chooses the first
    one(s) for each LayerChoice and InputChoice.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    ClassicMutator(model)


class ClassicMutator(Mutator):
    """
    This mutator applies the architecture chosen by the tuner.
    It implements the forward function of LayerChoice and InputChoice,
    to only activate the chosen ones.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """

    def __init__(self, model):
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()
        if NNI_GEN_SEARCH_SPACE in os.environ:
            # dry run for only generating search space
            self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
            sys.exit(0)

        if trial_env_vars.NNI_PLATFORM is None:
            logger.warning("This is in standalone mode, the chosen are the first one(s).")
            self._chosen_arch = self._standalone_generate_chosen()
        else:
            # get chosen arch from tuner
            self._chosen_arch = nni.get_next_parameter()
            if self._chosen_arch is None:
                if trial_env_vars.NNI_PLATFORM == "unittest":
                    # happens if NNI_PLATFORM is intentionally set, e.g., in UT
                    logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
                    self._chosen_arch = self._standalone_generate_chosen()
                else:
                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Convert layer choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Number `idx` of list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for corresponding search space.
        """
        # doesn't support multihot for layer choice yet
        onehot_list = [False] * len(mutable)
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        onehot_list[idx] = True
        return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Convert input choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Number `idx` of list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for corresponding search space.
        """
        candidate_repr = search_space_item["candidates"]
        multihot_list = [False] * mutable.n_candidates
        for i, v in zip(idx, value):
            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
            assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
            multihot_list[i] = True
        return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch and apply it on model.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
            if isinstance(mutable, LayerChoice):
                result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
                                                                self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, InputChoice):
                result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
                                                                self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return result

    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.

        ::

            { key_name: {"_value": "conv1",
                         "_idx": 0} }

            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                chosen_arch[key] = {"_value": choices[0], "_idx": 0}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch

    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:

        ::

            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }

            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                key = mutable.key
                val = mutable.names
                search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                search_space[key] = {"_type": INPUT_CHOICE,
                                     "_value": {"candidates": mutable.choose_from,
                                                "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return search_space

    def _dump_search_space(self, file_path):
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
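To make the standalone-mode fallback concrete, the sketch below applies the same "pick the first option(s)" rule to a search space written in the documented format. It is pure Python, not part of the deleted file, and the search-space keys and candidate names are made up for illustration.

# Pure-Python sketch of the standalone-mode fallback: pick the first option of every choice.
search_space = {
    "conv_block": {"_type": "layer_choice", "_value": ["conv3x3", "conv5x5", "maxpool"]},
    "skip_connect": {"_type": "input_choice",
                     "_value": {"candidates": ["layer1", "layer2", "layer3"], "n_chosen": 1}},
}

chosen_arch = {}
for key, val in search_space.items():
    if val["_type"] == "layer_choice":
        chosen_arch[key] = {"_value": val["_value"][0], "_idx": 0}
    elif val["_type"] == "input_choice":
        n = val["_value"]["n_chosen"] or len(val["_value"]["candidates"])
        chosen_arch[key] = {"_value": val["_value"]["candidates"][:n], "_idx": list(range(n))}

print(chosen_arch)
# {'conv_block': {'_value': 'conv3x3', '_idx': 0},
#  'skip_connect': {'_value': ['layer1'], '_idx': [0]}}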
nni/algorithms/nas/pytorch/cream/trainer.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from copy import deepcopy

import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup

from .utils import accuracy, reduce_metrics

logger = logging.getLogger(__name__)


class CreamSupernetTrainer(Trainer):
    """
    This trainer trains a supernet and outputs prioritized architectures that can be used for other tasks.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    val_loss : callable
        Called with logits and targets for validation only. Returns a loss tensor.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    batch_size : int
        Batch size.
    log_frequency : int
        Number of mini-batches to log metrics.
    meta_sta_epoch : int
        Start epoch for using the meta matching network to pick the teacher architecture.
    update_iter : int
        Interval of updating meta matching networks.
    slices : int
        Batch size of mini training data in the process of training the meta matching network.
    pool_size : int
        Board size.
    pick_method : basestring
        How to pick the teacher network.
    choice_num : int
        Number of operations in the supernet.
    sta_num : int
        Layer number of each stage in the supernet (5 stages in the supernet).
    acc_gap : int
        Maximum accuracy improvement to omit the limitation of flops.
    flops_dict : Dict
        Dictionary of each layer's operations in the supernet.
    flops_fixed : int
        Flops of the fixed part of the supernet.
    local_rank : int
        Index of the current rank.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, val_loss,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, log_frequency=None, meta_sta_epoch=20,
                 update_iter=200, slices=2, pool_size=10, pick_method='meta',
                 choice_num=6, sta_num=(4, 4, 4, 4, 4), acc_gap=5,
                 flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
        assert torch.cuda.is_available()
        super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
                                                   optimizer, num_epochs, None, None,
                                                   batch_size, None, None, log_frequency, callbacks)
        self.model = model
        self.loss = loss
        self.val_loss = val_loss
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.model = model
        self.loss = loss
        self.num_epochs = num_epochs
        self.meta_sta_epoch = meta_sta_epoch
        self.update_iter = update_iter
        self.slices = slices
        self.pick_method = pick_method
        self.pool_size = pool_size
        self.local_rank = local_rank
        self.choice_num = choice_num
        self.sta_num = sta_num
        self.acc_gap = acc_gap
        self.flops_dict = flops_dict
        self.flops_fixed = flops_fixed

        self.current_student_arch = None
        self.current_teacher_arch = None
        self.main_proc = (local_rank == 0)
        self.current_epoch = 0

        self.prioritized_board = []

    # size of prioritized board
    def _board_size(self):
        return len(self.prioritized_board)

    # select teacher architecture according to the logit difference
    def _select_teacher(self):
        self._replace_mutator_cand(self.current_student_arch)

        if self.pick_method == 'top1':
            meta_value, teacher_cand = 0.5, sorted(
                self.prioritized_board, reverse=True)[0][3]
        elif self.pick_method == 'meta':
            meta_value, cand_idx, teacher_cand = -1000000000, -1, None
            for now_idx, item in enumerate(self.prioritized_board):
                inputx = item[4]
                output = torch.nn.functional.softmax(self.model(inputx), dim=1)
                weight = self.model.module.forward_meta(output - item[5])
                if weight > meta_value:
                    meta_value = weight
                    cand_idx = now_idx
                    teacher_cand = self.prioritized_board[cand_idx][3]
            assert teacher_cand is not None
            meta_value = torch.nn.functional.sigmoid(-weight)
        else:
            raise ValueError('Method Not supported')

        return meta_value, teacher_cand

    # check whether to update prioritized board
    def _isUpdateBoard(self, prec1, flops):
        if self.current_epoch <= self.meta_sta_epoch:
            return False

        if len(self.prioritized_board) < self.pool_size:
            return True

        if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
            return True

        if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
            return True

        return False

    # update prioritized board
    def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
        if self._isUpdateBoard(prec1, flops):
            val_prec1 = prec1
            training_data = deepcopy(inputs[:self.slices].detach())
            if len(self.prioritized_board) == 0:
                features = deepcopy(outputs[:self.slices].detach())
            else:
                features = deepcopy(teacher_output[:self.slices].detach())
            self.prioritized_board.append(
                (val_prec1,
                 prec1,
                 flops,
                 self.current_student_arch,
                 training_data,
                 torch.nn.functional.softmax(features, dim=1)))
            self.prioritized_board = sorted(self.prioritized_board, reverse=True)

        if len(self.prioritized_board) > self.pool_size:
            del self.prioritized_board[-1]

    # only update student network weights
    def _update_student_weights_only(self, grad_1):
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            weight.grad = grad_item
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(self.current_student_arch), 1)
        self.optimizer.step()
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            del weight.grad

    # only update meta networks weights
    def _update_meta_weights_only(self, teacher_cand, grad_teacher):
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            weight.grad = grad_item

        # clip gradients
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(
                self.current_student_arch, self.pick_method == 'meta'), 1)

        self.optimizer.step()
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            del weight.grad

    # simulate sgd updating
    def _simulate_sgd_update(self, w, g, optimizer):
        return g * optimizer.param_groups[-1]['lr'] + w

    # split training images into several slices
    def _get_minibatch_input(self, input):  # pylint: disable=redefined-builtin
        slice = self.slices  # pylint: disable=redefined-builtin
        x = deepcopy(input[:slice].clone().detach())
        return x

    # calculate 1st gradient of student architectures
    def _calculate_1st_gradient(self, kd_loss):
        self.optimizer.zero_grad()
        grad = torch.autograd.grad(
            kd_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            create_graph=True)
        return grad

    # calculate 2nd gradient of meta networks
    def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
        self.optimizer.zero_grad()
        grad_student_val = torch.autograd.grad(
            validation_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            retain_graph=True)

        grad_teacher = torch.autograd.grad(
            students_weight[0],
            self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'),
            grad_outputs=grad_student_val)
        return grad_teacher

    # forward training data
    def _forward_training(self, x, meta_value):
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(x)

        with torch.no_grad():
            self._replace_mutator_cand(self.current_teacher_arch)
            teacher_output = self.model(x)
            soft_label = torch.nn.functional.softmax(teacher_output, dim=1)

        kd_loss = meta_value * \
            self._cross_entropy_loss_with_soft_target(output, soft_label)
        return kd_loss

    # calculate soft target loss
    def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
        logsoftmax = torch.nn.LogSoftmax()
        return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1))

    # forward validation data
    def _forward_validation(self, input, target):  # pylint: disable=redefined-builtin
        slice = self.slices  # pylint: disable=redefined-builtin
        x = input[slice:slice * 2].clone()

        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)

        validation_loss = self.loss(output_2, target[slice:slice * 2])
        return validation_loss

    def _isUpdateMeta(self, batch_idx):
        isUpdate = True
        isUpdate &= (self.current_epoch > self.meta_sta_epoch)
        isUpdate &= (batch_idx > 0)
        isUpdate &= (batch_idx % self.update_iter == 0)
        isUpdate &= (self._board_size() > 0)
        return isUpdate

    def _replace_mutator_cand(self, cand):
        self.mutator._cache = cand

    # update meta matching networks
    def _run_update(self, input, target, batch_idx):  # pylint: disable=redefined-builtin
        if self._isUpdateMeta(batch_idx):
            x = self._get_minibatch_input(input)

            meta_value, teacher_cand = self._select_teacher()

            kd_loss = self._forward_training(x, meta_value)

            # calculate 1st gradient
            grad_1st = self._calculate_1st_gradient(kd_loss)

            # simulate updated student weights
            students_weight = [
                self._simulate_sgd_update(
                    p, grad_item, self.optimizer) for p, grad_item in zip(
                    self.model.module.rand_parameters(self.current_student_arch), grad_1st)]

            # update student weights
            self._update_student_weights_only(grad_1st)

            validation_loss = self._forward_validation(input, target)

            # calculate 2nd gradient
            grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)

            # update meta matching networks
            self._update_meta_weights_only(teacher_cand, grad_teacher)

            # delete internal variants
            del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight

    def _get_cand_flops(self, cand):
        flops = 0
        for block_id, block in enumerate(cand):
            if block == 'LayerChoice1' or block_id == 'LayerChoice23':
                continue
            for idx, choice in enumerate(cand[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed

    def train_one_epoch(self, epoch):
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()

                    soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
                kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)

                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1, self.num_epochs,
                            step + 1, len(self.train_loader), meters)

        if self.main_proc and self.num_epochs == epoch + 1:
            for idx, i in enumerate(self.prioritized_board):
                logger.info("No.%s %s", idx, i[:4])

    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)
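The heart of ``train_one_epoch`` is the blend of a soft-target distillation loss from the teacher path with the ground-truth loss, weighted by ``meta_value``. The sketch below reproduces just that arithmetic with random tensors; it is not part of the deleted file, and the ``meta_value`` and shapes are illustrative.

# Standalone sketch of the teacher-distillation blend used in train_one_epoch.
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
meta_value = 0.6                                     # teacher confidence from the matching network

soft_label = F.softmax(teacher_logits, dim=1)
kd_loss = torch.mean(torch.sum(-soft_label * F.log_softmax(student_logits, dim=1), dim=1))
gt_loss = F.cross_entropy(student_logits, target)

loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2
print(loss)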
nni/algorithms/nas/pytorch/cream/utils.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

import torch.distributed as dist


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def reduce_metrics(metrics):
    return {k: reduce_tensor(v).item() for k, v in metrics.items()}


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(os.environ["WORLD_SIZE"])
    return rt
nni/algorithms/nas/pytorch/darts/mutator.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

_logger = logging.getLogger(__name__)


class DartsMutator(Mutator):
    """
    Connects the model in a DARTS (differentiable) way.

    An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
    op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
    (not chosen).

    All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
    on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
    will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.

    It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
    value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
    be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.

    Attributes
    ----------
    choices: ParameterDict
        dict that maps keys of LayerChoices to weighted-connection float tensors.
    """
    def __init__(self, model):
        super().__init__(model)
        self.choices = nn.ParameterDict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))

    def device(self):
        for v in self.choices.values():
            return v.device

    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
            elif isinstance(mutable, InputChoice):
                result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
        return result

    def sample_final(self):
        result = dict()
        edges_max = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
                edges_max[mutable.key] = max_val
                result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
        for mutable in self.mutables:
            if isinstance(mutable, InputChoice):
                if mutable.n_chosen is not None:
                    weights = []
                    for src_key in mutable.choose_from:
                        if src_key not in edges_max:
                            _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
                        weights.append(edges_max.get(src_key, 0.))
                    weights = torch.tensor(weights)  # pylint: disable=not-callable
                    _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
                    selected_multihot = []
                    for i, src_key in enumerate(mutable.choose_from):
                        if i not in topk_edge_indices and src_key in result:
                            # If an edge is never selected, there is no need to calculate any op on this edge.
                            # This is to eliminate redundant calculation.
                            result[src_key] = torch.zeros_like(result[src_key])
                        selected_multihot.append(i in topk_edge_indices)
                    result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
                else:
                    result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
        return result
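The docstring above describes an implicit extra "zero op" slot per LayerChoice: the softmax runs over ``length + 1`` logits and the last one is dropped, so the kept weights may sum to less than one. The following standalone sketch (plain PyTorch, not part of the deleted file; the op count is illustrative) shows what ``sample_search`` and ``sample_final`` produce for a single choice.

# Standalone sketch: continuous relaxation with an extra "zero op" slot, as in DartsMutator.
import torch
import torch.nn.functional as F

n_ops = 3                                                    # e.g. conv3x3, conv5x5, maxpool
alpha = torch.nn.Parameter(1e-3 * torch.randn(n_ops + 1))   # last slot plays the role of the zero op

search_weights = F.softmax(alpha, dim=-1)[:-1]               # continuous weights used during search
print(search_weights)                                        # may sum to < 1: the rest went to the zero op

max_val, index = torch.max(F.softmax(alpha, dim=-1)[:-1], 0)
final_decision = F.one_hot(index, num_classes=n_ops).view(-1).bool()
print(final_decision)                                        # exported decision: a one-hot boolean mask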
nni/algorithms/nas/pytorch/darts/trainer.py    deleted    100644 → 0

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging

import torch
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup

from .mutator import DartsMutator

logger = logging.getLogger(__name__)


class DartsTrainer(Trainer):
    """
    DARTS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, return a loss tensor.
    metrics : callable
        Receives logits and ground truth label, return a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : DartsMutator
        Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        list of callbacks to trigger at events.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    unrolled : float
        ``True`` if using second order optimization, else first order optimization.
    """
    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
        super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                           weight_decay=1.0E-3)
        self.unrolled = unrolled

        n_train = len(self.dataset_train)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=batch_size,
                                                       num_workers=workers)

    def train_one_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
            val_X, val_y = val_X.to(self.device), val_y.to(self.device)

            # phase 1. architecture step
            self.ctrl_optim.zero_grad()
            if self.unrolled:
                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
            else:
                self._backward(val_X, val_y)
            self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)

    def _logits_and_loss(self, X, y):
        self.mutator.reset()
        logits = self.model(X)
        loss = self.loss(logits, y)
        self._write_graph_status()
        return logits, loss

    def _backward(self, val_X, val_y):
        """
        Simple backward with gradient descent
        """
        _, loss = self._logits_and_loss(val_X, val_y)
        loss.backward()

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
        """
        Compute unrolled loss and backward its gradients
        """
        backup_params = copy.deepcopy(tuple(self.model.parameters()))

        # do virtual step on training data
        lr = self.optimizer.param_groups[0]["lr"]
        momentum = self.optimizer.param_groups[0]["momentum"]
        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)

        # calculate unrolled loss on validation data
        # keep gradients for model here for compute hessian
        _, loss = self._logits_and_loss(val_X, val_y)
        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
        with torch.no_grad():
            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)

    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
        """
        Compute unrolled weights w`
        """
        # don't need zero_grad, using autograd to calculate gradients
        _, loss = self._logits_and_loss(X, y)
        gradients = torch.autograd.grad(loss, self.model.parameters())
        with torch.no_grad():
            for w, g in zip(self.model.parameters(), gradients):
                m = self.optimizer.state[w].get("momentum_buffer", 0.)
                w = w - lr * (momentum * m + g + weight_decay * w)

    def _restore_weights(self, backup_params):
        with torch.no_grad():
            for param, backup in zip(self.model.parameters(), backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
        """
            dw = dw` { L_val(w`, alpha) }
            w+ = w + eps * dw
            w- = w - eps * dw
            hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
            eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.model.parameters(), dw):
                    p += e * d
            _, loss = self._logits_and_loss(trn_X, trn_y)
            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))

        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
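A detail worth highlighting from ``__init__`` above: bi-level optimization needs two disjoint streams of data, so the trainer splits ``dataset_train`` in half with ``SubsetRandomSampler`` and draws one batch from each half per step. The sketch below reproduces that setup on a toy ``TensorDataset``; it is not part of the deleted file, and the dataset size, shapes, and batch size are illustrative.

# Standalone sketch of the 50/50 train/val split used for bi-level optimization.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

dataset_train = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 10, (100,)))

n_train = len(dataset_train)
split = n_train // 2
indices = list(range(n_train))
train_sampler = SubsetRandomSampler(indices[:split])   # half for model weights
valid_sampler = SubsetRandomSampler(indices[split:])   # half for architecture parameters

train_loader = DataLoader(dataset_train, batch_size=16, sampler=train_sampler)
valid_loader = DataLoader(dataset_train, batch_size=16, sampler=valid_sampler)

for (trn_X, trn_y), (val_X, val_y) in zip(train_loader, valid_loader):
    print(trn_X.shape, val_X.shape)                    # each step sees one batch from each half
    break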
nni/algorithms/nas/pytorch/enas/trainer.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from itertools import cycle

import torch
import torch.nn as nn
import torch.optim as optim

from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator

logger = logging.getLogger(__name__)


class EnasTrainer(Trainer):
    """
    ENAS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    reward_function : callable
        Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : EnasMutator
        Use when customizing your own mutator or a mutator with customized parameters.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    child_steps : int
        How many mini-batches for model training per epoch.
    mutator_lr : float
        Learning rate for RL controller.
    mutator_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    mutator_steps : int
        Number of mini-batches for each epoch of RL controller learning.
    aux_weight : float
        Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
    test_arc_per_epoch : int
        How many architectures are chosen for direct test after each epoch.
    """

    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
                 test_arc_per_epoch=1):
        super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.reward_function = reward_function
        self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.batch_size = batch_size
        self.workers = workers

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.mutator_steps = mutator_steps
        self.child_steps = child_steps
        self.aux_weight = aux_weight
        self.test_arc_per_epoch = test_arc_per_epoch

        self.init_dataloader()

    def init_dataloader(self):
        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=self.workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=self.workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=self.batch_size,
                                                       num_workers=self.workers)
        self.train_loader = cycle(self.train_loader)
        self.valid_loader = cycle(self.valid_loader)

    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                self._write_graph_status()
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate, meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()

    def validate_one_epoch(self, epoch):
        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                            meters.summary())
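The ``baseline_decay`` update documented above is a plain exponential moving average of the controller reward, used to reduce the variance of the REINFORCE gradient. A tiny illustration with made-up reward values:

baseline_decay = 0.999
baseline = 0.0
for reward in [0.52, 0.55, 0.61]:   # rewards from successive sampled architectures (toy numbers)
    baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
print(baseline)                     # slowly tracks the average reward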
nni/algorithms/nas/pytorch/fbnet/__init__.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import

from .mutator import FBNetMutator  # noqa: F401
from .trainer import FBNetTrainer  # noqa: F401
from .utils import (  # noqa: F401
    LookUpTable,
    NASConfig,
    RegularizerLoss,
    model_init,
    supernet_sample,
)
nni/algorithms/nas/pytorch/fbnet/mutator.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice


class MixedOp(nn.Module):
    """
    This class is to instantiate and manage info of one LayerChoice.
    It includes architecture weights and member functions for the weights.
    """

    def __init__(self, mutable, latency):
        """
        Parameters
        ----------
        mutable : LayerChoice
            A LayerChoice in user model
        latency : List
            performance cost for each op in mutable
        """
        super(MixedOp, self).__init__()
        self.latency = latency
        n_choices = len(mutable)
        self.path_alpha = nn.Parameter(
            torch.FloatTensor([1.0 / n_choices for i in range(n_choices)])
        )
        self.path_alpha.requires_grad = False
        self.temperature = 1.0

    def get_path_alpha(self):
        """Return the architecture parameter."""
        return self.path_alpha

    def get_weighted_latency(self):
        """Return the weighted perf_cost of current mutable."""
        soft_masks = self.probs_over_ops()
        weighted_latency = sum(m * l for m, l in zip(soft_masks, self.latency))
        return weighted_latency

    def set_temperature(self, temperature):
        """
        Set the annealed temperature for gumbel softmax.

        Parameters
        ----------
        temperature : float
            The annealed temperature for gumbel softmax
        """
        self.temperature = temperature

    def to_requires_grad(self):
        """Enable gradient calculation."""
        self.path_alpha.requires_grad = True

    def to_disable_grad(self):
        """Disable gradient calculation."""
        self.path_alpha.requires_grad = False

    def probs_over_ops(self):
        """Apply gumbel softmax to generate probability distribution."""
        return F.gumbel_softmax(self.path_alpha, self.temperature)

    def forward(self, mutable, x):
        """
        Define forward of LayerChoice.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        x : tensor
            inputs of this layer, only support one input

        Returns
        -------
        output: tensor
            output of this layer
        """
        candidate_ops = list(mutable)
        soft_masks = self.probs_over_ops()
        output = sum(m * op(x) for m, op in zip(soft_masks, candidate_ops))
        return output

    @property
    def chosen_index(self):
        """
        Choose the op with max prob.

        Returns
        -------
        int
            index of the chosen one
        """
        alphas = self.path_alpha.data.detach().cpu().numpy()
        index = int(np.argmax(alphas))
        return index


class FBNetMutator(BaseMutator):
    """
    This mutator initializes and operates all the LayerChoices of the supernet.
    It is for the related trainer to control the training flow of LayerChoices,
    coordinating with the whole training process.
    """

    def __init__(self, model, lookup_table):
        """
        Init a MixedOp instance for each mutable, i.e., LayerChoice,
        and register the instantiated MixedOp in the corresponding LayerChoice.
        If it is not registered in LayerChoice, DataParallel does not work,
        because architecture weights are not included in the DataParallel model.
        When MixedOps are registered, we use ``requires_grad`` to control
        whether to calculate gradients of architecture weights.

        Parameters
        ----------
        model : pytorch model
            The model that users want to tune,
            it includes search space defined with nni nas apis
        lookup_table : class
            lookup table object to manage model space information,
            including candidate ops for each stage as the model space,
            input channels/output channels/stride/fm_size as the layer config,
            and the performance information for perf_cost accumulation.
        """
        super(FBNetMutator, self).__init__(model)
        self.mutable_list = []

        # Collect the op names of the candidate ops within each mutable
        ops_names_mutable = dict()
        left = 0
        right = 1
        for stage_name in lookup_table.layer_num:
            right = lookup_table.layer_num[stage_name]
            stage_ops = lookup_table.lut_ops[stage_name]
            ops_names = [op_name for op_name in stage_ops]

            for i in range(left, left + right):
                ops_names_mutable[i] = ops_names
            left += right

        # Create the mixed op
        for i, mutable in enumerate(self.undedup_mutables):
            ops_names = ops_names_mutable[i]
            latency_mutable = lookup_table.lut_perf[i]
            latency = [latency_mutable[op_name] for op_name in ops_names]
            self.mutable_list.append(mutable)
            mutable.registered_module = MixedOp(mutable, latency)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward. This function defines the forward
        logic of the input mutable. So mutable is only an interface, its real
        implementation is defined in the mutator.

        Parameters
        ----------
        mutable: LayerChoice
            forward logic of this input mutable
        args: list of torch.Tensor
            inputs of this mutable
        kwargs: dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        idx = mutable.registered_module.chosen_index
        return mutable.registered_module(mutable, *args, **kwargs), idx

    def num_arch_params(self):
        """
        The number of mutables, i.e., LayerChoice

        Returns
        -------
        int
            the number of LayerChoice in user model
        """
        return len(self.mutable_list)

    def get_architecture_parameters(self):
        """
        Get all the architecture parameters.

        yield
        -----
        PyTorch Parameter
            Return path_alpha of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_path_alpha()

    def get_weighted_latency(self):
        """
        Get the latency weighted by gumbel softmax coefficients.

        yield
        -----
        Tuple
            Return the weighted_latency of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_weighted_latency()

    def set_temperature(self, temperature):
        """
        Set the annealed temperature of the op for gumbel softmax.

        Parameters
        ----------
        temperature : float
            The annealed temperature for gumbel softmax
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_temperature(temperature)

    def arch_requires_grad(self):
        """
        Make architecture weights require gradient
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradient of architecture weights, i.e., do not
        calculate gradient for them.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice
        """
        result = dict()
        for mutable in self.undedup_mutables:
            assert isinstance(mutable, LayerChoice)
            index = mutable.registered_module.chosen_index
            # pylint: disable=not-callable
            result[mutable.key] = (
                F.one_hot(torch.tensor(index), num_classes=len(mutable))
                .view(-1)
                .bool(),
            )
        return result
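The MixedOp above relaxes the discrete layer choice with a Gumbel-softmax over ``path_alpha``: each candidate's output is weighted by a relaxed one-hot sample. A small self-contained sketch of that weighting with toy ops and shapes (not taken from the deleted file):

import torch
import torch.nn as nn
import torch.nn.functional as F

# three toy candidate ops competing for one layer
candidates = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1),
                            nn.Conv2d(8, 8, 5, padding=2),
                            nn.Identity()])
path_alpha = torch.full((3,), 1.0 / 3)               # uniform init, as in MixedOp.__init__
x = torch.randn(2, 8, 16, 16)

soft_mask = F.gumbel_softmax(path_alpha, tau=1.0)    # relaxed one-hot sample over the ops
out = sum(m * op(x) for m, op in zip(soft_mask, candidates))
print(out.shape, soft_mask)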
nni/algorithms/nas/pytorch/fbnet/trainer.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import json
import os
import time
import torch
import numpy as np

from torch.autograd import Variable
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import FBNetMutator
from .utils import RegularizerLoss, accuracy


class FBNetTrainer(BaseTrainer):
    def __init__(self, model, model_optim, criterion, device, device_ids, lookup_table,
                 train_loader, valid_loader, n_epochs=120, load_ckpt=False, arch_path=None, logger=None):
        """
        Parameters
        ----------
        model : pytorch model
            the user model, which has mutables
        model_optim : pytorch optimizer
            the user defined optimizer
        criterion : pytorch loss
            the main task loss, nn.CrossEntropyLoss() is for classification
        device : pytorch device
            the devices to train/search the model
        device_ids : list of int
            the indexes of devices used for training
        lookup_table : class
            lookup table object for fbnet training
        train_loader : pytorch data loader
            data loader for the training set
        valid_loader : pytorch data loader
            data loader for the validation set
        n_epochs : int
            number of epochs to train/search
        load_ckpt : bool
            whether to load a checkpoint
        arch_path : str
            the path to store the chosen architecture
        logger : logger
            the logger
        """
        self.model = model
        self.model_optim = model_optim
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.dev_num = len(device_ids)
        self.n_epochs = n_epochs
        self.lookup_table = lookup_table
        self.config = lookup_table.config
        self.start_epoch = self.config.start_epoch
        self.temp = self.config.init_temperature
        self.exp_anneal_rate = self.config.exp_anneal_rate
        self.mode = self.config.mode

        self.load_ckpt = load_ckpt
        self.arch_path = arch_path
        self.logger = logger

        # scheduler of learning rate
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_optim, T_max=n_epochs, last_epoch=-1
        )

        # init mutator
        self.mutator = FBNetMutator(model, lookup_table)
        self.mutator.set_temperature(self.temp)

        # DataParallel should be put behind the init of mutator
        self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        self.model.to(device)

        # build architecture optimizer
        self.arch_optimizer = torch.optim.AdamW(
            self.mutator.get_architecture_parameters(),
            self.config.nas_lr,
            weight_decay=self.config.nas_weight_decay,
        )
        self.reg_loss = RegularizerLoss(config=self.config)

        self.criterion = criterion
        self.epoch = 0

    def _layer_choice_sample(self):
        """
        Sample the index of network within layer choice.
        """
        stages = [stage_name for stage_name in self.lookup_table.layer_num]
        stage_lnum = [self.lookup_table.layer_num[stage] for stage in stages]

        # get the choice idx in each layer
        choice_ids = list()
        layer_id = 0
        for param in self.mutator.get_architecture_parameters():
            param_np = param.cpu().detach().numpy()
            op_idx = np.argmax(param_np)
            choice_ids.append(op_idx)
            self.logger.info("layer {}: {}, index: {}".format(layer_id, param_np, op_idx))
            layer_id += 1

        # get the arch_sample
        choice_names = list()
        layer_id = 0
        for i, stage_name in enumerate(stages):
            ops_names = [op for op in self.lookup_table.lut_ops[stage_name]]
            for _ in range(stage_lnum[i]):
                searched_op = ops_names[choice_ids[layer_id]]
                choice_names.append(searched_op)
                layer_id += 1

        self.logger.info(choice_names)
        return choice_names

    def _get_perf_cost(self, requires_grad=True):
        """
        Get the accumulated performance cost.
        """
        perf_cost = Variable(
            torch.zeros(1), requires_grad=requires_grad
        ).to(self.device, non_blocking=True)

        for latency in self.mutator.get_weighted_latency():
            perf_cost = perf_cost + latency

        return perf_cost

    def _validate(self):
        """
        Do validation. During validation, LayerChoices use the mixed-op.

        Returns
        -------
        float, float, float
            average loss, average top1 accuracy, average top5 accuracy
        """
        self.valid_loader.batch_sampler.drop_last = False
        batch_time = AverageMeter("batch_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # test on validation set under eval mode
        self.model.eval()

        end = time.time()
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_loader):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                output = self.model(images)

                loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                losses.update(loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == len(self.valid_loader):
                    test_log = (
                        "Valid" + ": [{0}/{1}]\t"
                        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                        "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                        "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                        "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})".format(
                            i, len(self.valid_loader) - 1,
                            batch_time=batch_time, loss=losses, top1=top1, top5=top5)
                    )
                    self.logger.info(test_log)

        return losses.avg, top1.avg, top5.avg

    def _train_epoch(self, epoch, optimizer, arch_train=False):
        """
        Train one epoch.
        """
        batch_time = AverageMeter("batch_time")
        data_time = AverageMeter("data_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # switch to train mode
        self.model.train()

        data_loader = self.valid_loader if arch_train else self.train_loader
        end = time.time()
        for i, (images, labels) in enumerate(data_loader):
            data_time.update(time.time() - end)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            output = self.model(images)
            loss = self.criterion(output, labels)

            # hardware-aware loss
            perf_cost = self._get_perf_cost(requires_grad=True)
            regu_loss = self.reg_loss(perf_cost)
            if self.mode.startswith("mul"):
                loss = loss * regu_loss
            elif self.mode.startswith("add"):
                loss = loss + regu_loss

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                batch_log = (
                    "Warmup Train [{0}][{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {losses.val:.4f} ({losses.avg:.4f})\t"
                    "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t".format(
                        epoch + 1, i,
                        batch_time=batch_time, data_time=data_time, losses=losses, top1=top1, top5=top5)
                )
                self.logger.info(batch_log)

    def _warm_up(self):
        """
        Warm up the model, while the architecture weights are not trained.
        """
        for epoch in range(self.epoch, self.start_epoch):
            self.logger.info("\n--------Warmup epoch: %d--------\n", epoch + 1)
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            # validation
            val_loss, val_top1, val_top5 = self._validate()
            val_log = (
                "Warmup Valid [{0}/{1}]\t"
                "loss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}".format(
                    epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5)
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(self.config.model_dir, "checkpoint_%s.pth" % epoch)
                self.save_checkpoint(epoch, filename)

    def _train(self):
        """
        Train the model: it trains both model weights and architecture weights.
        Architecture weights are trained according to the schedule.
        Before updating architecture weights, ``requires_grad`` is enabled.
        Then, it is disabled after the update, in order not to update
        architecture weights when training model weights.
        """
        arch_param_num = self.mutator.num_arch_params()
        self.logger.info("#arch_params: {}".format(arch_param_num))
        self.epoch = max(self.start_epoch, self.epoch)

        ckpt_path = self.config.model_dir
        choice_names = None
        top1_best = 0.0

        for epoch in range(self.epoch, self.n_epochs):
            self.logger.info("\n--------Train epoch: %d--------\n", epoch + 1)
            # update the weight parameters
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            self.logger.info("Update architecture parameters")
            # update the architecture parameters
            self.mutator.arch_requires_grad()
            self._train_epoch(epoch, self.arch_optimizer, True)
            self.mutator.arch_disable_grad()
            # temperature annealing
            self.temp = self.temp * self.exp_anneal_rate
            self.mutator.set_temperature(self.temp)
            # sample the architecture of sub-network
            choice_names = self._layer_choice_sample()

            # validate
            val_loss, val_top1, val_top5 = self._validate()
            val_log = (
                "Valid [{0}]\t"
                "loss {1:.3f}\ttop-1 acc {2:.3f}\ttop-5 acc {3:.3f}".format(
                    epoch + 1, val_loss, val_top1, val_top5)
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(ckpt_path, "checkpoint_%s.pth" % epoch)
                self.save_checkpoint(epoch, filename, choice_names)

            val_top1 = val_top1.cpu().as_numpy()
            if val_top1 > top1_best:
                filename = os.path.join(ckpt_path, "checkpoint_best.pth")
                self.save_checkpoint(epoch, filename, choice_names)
                top1_best = val_top1

    def save_checkpoint(self, epoch, filename, choice_names=None):
        """
        Save checkpoint of the whole model.
        Saving model weights and architecture weights as ``filename``,
        and saving the currently chosen architecture in ``arch_path``.
        """
        state = {
            "model": self.model.state_dict(),
            "optim": self.model_optim.state_dict(),
            "epoch": epoch,
            "arch_sample": choice_names,
        }
        torch.save(state, filename)
        self.logger.info("Save checkpoint to {0:}".format(filename))

        if self.arch_path:
            self.export(self.arch_path)

    def load_checkpoint(self, filename):
        """
        Load the checkpoint from ``ckpt_path``.
        """
        ckpt = torch.load(filename)
        self.epoch = ckpt["epoch"]
        self.model.load_state_dict(ckpt["model"])
        self.model_optim.load_state_dict(ckpt["optim"])

    def train(self):
        """
        Train the whole model.
        """
        if self.load_ckpt:
            ckpt_path = self.config.model_dir
            filename = os.path.join(ckpt_path, "checkpoint_best.pth")
            if os.path.exists(filename):
                self.load_checkpoint(filename)

        if self.epoch < self.start_epoch:
            self._warm_up()
        self._train()

    def export(self, file_name):
        """
        Export the chosen architecture into a file.

        Parameters
        ----------
        file_name : str
            the file that stores the exported chosen architecture
        """
        exported_arch = self.mutator.sample_final()
        with open(file_name, "w") as f:
            json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def validate(self):
        raise NotImplementedError

    def checkpoint(self):
        raise NotImplementedError
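One detail worth calling out from _train is the temperature schedule: ``temp`` is multiplied by ``exp_anneal_rate`` after every epoch, so the Gumbel-softmax in the mutator gradually hardens. With the defaults from NASConfig (init_temperature=5.0, exp_anneal_rate=exp(-0.045)) the decay looks roughly like this:

import numpy as np

init_temperature, exp_anneal_rate = 5.0, np.exp(-0.045)
temps = [init_temperature * exp_anneal_rate ** epoch for epoch in range(0, 100, 20)]
print([round(t, 3) for t in temps])   # the sampled op distribution gets sharper as training proceeds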
nni/algorithms/nas/pytorch/fbnet/utils.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import ast
import os
import timeit
import torch

import numpy as np
import torch.nn as nn

from nni.compression.pytorch.utils import count_flops_params

LUT_FILE = "lut.npy"
LUT_JSON_FILE = "lut.txt"
LUT_PATH = "lut"

DATA_TYPE = "float"


class NASConfig:
    def __init__(
        self,
        perf_metric="flops",
        lut_load=False,
        lut_load_format="json",
        model_dir=None,
        nas_lr=0.01,
        nas_weight_decay=5e-4,
        mode="mul",
        alpha=0.25,
        beta=0.6,
        start_epoch=50,
        init_temperature=5.0,
        exp_anneal_rate=np.exp(-0.045),
        search_space=None,
    ):
        # LUT of performance metric
        # flops means the multiplies, latency means the time cost on platform
        self.perf_metric = perf_metric
        assert perf_metric in ["flops", "latency"], "perf_metric should be ['flops', 'latency']"
        # whether to load or create the lut file
        self.lut_load = lut_load
        assert lut_load_format in ["json", "numpy"], "lut_load_format should be ['json', 'numpy']"
        self.lut_load_format = lut_load_format
        # necessary dirs
        self.lut_en = model_dir is not None
        if self.lut_en:
            self.model_dir = model_dir
            os.makedirs(model_dir, exist_ok=True)
            self.lut_path = os.path.join(model_dir, LUT_PATH)
            os.makedirs(self.lut_path, exist_ok=True)
        # NAS learning setting
        self.nas_lr = nas_lr
        self.nas_weight_decay = nas_weight_decay
        # hardware-aware loss setting
        self.mode = mode
        assert mode in ["mul", "add"], "mode should be ['mul', 'add']"
        self.alpha = alpha
        self.beta = beta
        # NAS training setting
        self.start_epoch = start_epoch
        self.init_temperature = init_temperature
        self.exp_anneal_rate = exp_anneal_rate
        # definition of search blocks and space
        self.search_space = search_space


class RegularizerLoss(nn.Module):
    """Auxiliary loss for hardware-aware NAS."""

    def __init__(self, config):
        """
        Parameters
        ----------
        config : class
            to manage the configuration for NAS training, and search space etc.
        """
        super(RegularizerLoss, self).__init__()
        self.mode = config.mode
        self.alpha = config.alpha
        self.beta = config.beta

    def forward(self, perf_cost, batch_size=1):
        """
        Parameters
        ----------
        perf_cost : tensor
            the accumulated performance cost
        batch_size : int
            batch size for normalization

        Returns
        -------
        output: tensor
            the hardware-aware constraint loss
        """
        if self.mode == "mul":
            log_loss = torch.log(perf_cost / batch_size) ** self.beta
            return self.alpha * log_loss
        elif self.mode == "add":
            linear_loss = (perf_cost / batch_size) ** self.beta
            return self.alpha * linear_loss
        else:
            raise NotImplementedError


def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output : pytorch tensor
        output, e.g., predicted value
    target : pytorch tensor
        label
    topk : tuple
        specify top1 and top5

    Returns
    -------
    list
        accuracy of top1 and top5
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def supernet_sample(model, state_dict, sampled_arch=[], lookup_table=None):
    """
    Initialize the searched sub-model from the supernet.

    Parameters
    ----------
    model : pytorch model
        the created subnet
    state_dict : checkpoint
        the checkpoint of supernet, including the pre-trained params
    sampled_arch : list of str
        the searched layer names of the subnet
    lookup_table : class
        to manage the candidate ops, layer information and layer performance
    """
    replace = list()
    stages = [stage for stage in lookup_table.layer_num]
    stage_lnum = [lookup_table.layer_num[stage] for stage in stages]

    if sampled_arch:
        layer_id = 0
        for i, stage in enumerate(stages):
            ops_names = [op_name for op_name in lookup_table.lut_ops[stage]]
            for _ in range(stage_lnum[i]):
                searched_op = sampled_arch[layer_id]
                op_i = ops_names.index(searched_op)
                replace.append(
                    [
                        "blocks.{}.".format(layer_id),
                        "blocks.{}.op.".format(layer_id),
                        "blocks.{}.{}.".format(layer_id, op_i),
                    ]
                )
                layer_id += 1
    model_init(model, state_dict, replace=replace)


def model_init(model, state_dict, replace=[]):
    """Initialize the model from state_dict."""
    prefix = "module."

    param_dict = dict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[7:]
        param_dict[k] = v

    for k, (name, m) in enumerate(model.named_modules()):
        if replace:
            for layer_replace in replace:
                assert len(layer_replace) == 3, "The elements should be three."
                pre_scope, key, replace_key = layer_replace
                if pre_scope in name:
                    name = name.replace(key, replace_key)

        # Copy the state_dict to current model
        if (name + ".weight" in param_dict) or (name + ".running_mean" in param_dict):
            if isinstance(m, nn.BatchNorm2d):
                shape = m.running_mean.shape
                if shape == param_dict[name + ".running_mean"].shape:
                    if m.weight is not None:
                        m.weight.data = param_dict[name + ".weight"]
                        m.bias.data = param_dict[name + ".bias"]
                    m.running_mean = param_dict[name + ".running_mean"]
                    m.running_var = param_dict[name + ".running_var"]

            elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                shape = m.weight.data.shape
                if shape == param_dict[name + ".weight"].shape:
                    m.weight.data = param_dict[name + ".weight"]
                    if m.bias is not None:
                        m.bias.data = param_dict[name + ".bias"]

            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data = param_dict[name + ".weight"]
                if m.bias is not None:
                    m.bias.data = param_dict[name + ".bias"]


class LookUpTable:
    """Build look-up table for NAS."""

    def __init__(self, config, primitives):
        """
        Parameters
        ----------
        config : class
            to manage the configuration for NAS training, and search space etc.
        """
        self.config = config
        # definition of search blocks and space
        self.search_space = config.search_space
        # layers for NAS
        self.cnt_layers = len(self.search_space["input_shape"])
        # constructors for each operation
        self.lut_ops = {
            stage_name: {
                op_name: primitives[op_name]
                for op_name in self.search_space["stages"][stage_name]["ops"]
            }
            for stage_name in self.search_space["stages"]
        }
        self.layer_num = {
            stage_name: self.search_space["stages"][stage_name]["layer_num"]
            for stage_name in self.search_space["stages"]
        }

        # arguments for the ops constructors, input_shapes just for convenience
        self.layer_configs, self.layer_in_shapes = self._layer_configs()

        # lookup_table
        self.perf_metric = config.perf_metric

        if config.lut_en:
            self.lut_perf = None
            self.lut_file = os.path.join(config.lut_path, LUT_FILE)
            self.lut_json_file = LUT_JSON_FILE
            if config.lut_load:
                if config.lut_load_format == "numpy":
                    # Load data from numpy file
                    self._load_from_file()
                else:
                    # Load data from json file
                    self._load_from_json_file()
            else:
                self._create_perfs()

    def _layer_configs(self):
        """Generate basic params for different layers."""
        # layer_configs are : c_in, c_out, stride, fm_size
        layer_configs = [
            [
                self.search_space["input_shape"][layer_id][0],
                self.search_space["channel_size"][layer_id],
                self.search_space["strides"][layer_id],
                self.search_space["fm_size"][layer_id],
            ]
            for layer_id in range(self.cnt_layers)
        ]

        # layer_in_shapes are (C_in, input_w, input_h)
        layer_in_shapes = self.search_space["input_shape"]

        return layer_configs, layer_in_shapes

    def _create_perfs(self, cnt_of_runs=200):
        """Create performance cost for each op."""
        if self.perf_metric == "latency":
            self.lut_perf = self._calculate_latency(cnt_of_runs)
        elif self.perf_metric == "flops":
            self.lut_perf = self._calculate_flops()

        self._write_lut_to_file()

    def _calculate_flops(self, eps=0.001):
        """FLOPs cost."""
        flops_lut = [{} for i in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    key_params = {"fm_size": layer_config[3]}
                    op = stage_ops[op_name](*layer_config[0:3], **key_params)

                    # measured in Flops
                    in_shape = self.layer_in_shapes[layer_id]
                    x = (1, in_shape[0], in_shape[1], in_shape[2])
                    flops, _, _ = count_flops_params(op, x, verbose=False)
                    flops = eps if flops == 0.0 else flops
                    flops_lut[layer_id][op_name] = float(flops)
                layer_id += 1

        return flops_lut

    def _calculate_latency(self, cnt_of_runs):
        """Latency cost."""
        LATENCY_BATCH_SIZE = 1
        latency_lut = [{} for i in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    key_params = {"fm_size": layer_config[3]}
                    op = stage_ops[op_name](*layer_config[0:3], **key_params)
                    input_data = torch.randn((LATENCY_BATCH_SIZE, *self.layer_in_shapes[layer_id]))
                    globals()["op"], globals()["input_data"] = op, input_data
                    total_time = timeit.timeit(
                        "output = op(input_data)",
                        setup="gc.enable()",
                        globals=globals(),
                        number=cnt_of_runs,
                    )
                    # measured in micro-second
                    latency_lut[layer_id][op_name] = total_time / cnt_of_runs / LATENCY_BATCH_SIZE * 1e6
                layer_id += 1

        return latency_lut

    def _write_lut_to_file(self):
        """Save lut as numpy file."""
        np.save(self.lut_file, self.lut_perf)

    def _load_from_file(self):
        """Load numpy file."""
        self.lut_perf = np.load(self.lut_file, allow_pickle=True)

    def _load_from_json_file(self):
        """Load json file."""

        """
        lut_json_file ('lut.txt') format:
        {'op_name': operator_name,
         'op_data_shape': (input_w, input_h, C_in, C_out, stride),
         'op_dtype': data_type,
         'op_latency': latency}
        {...}
        {...}
        """
        latency_file = open(self.lut_json_file, "r")
        ops_latency = latency_file.readlines()

        """ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}"""
        ops_lut = {}

        for op_latency in ops_latency:
            assert isinstance(op_latency, str) or isinstance(op_latency, dict)

            if isinstance(op_latency, str):
                record = ast.literal_eval(op_latency)
            elif isinstance(op_latency, dict):
                record = op_latency

            op_name = record["op_name"]
            """op_data_shape: (input_w, input_h, C_in, C_out, stride)"""
            op_data_shape = record["op_data_shape"]
            op_dtype = record["op_dtype"]
            op_latency = record["op_latency"]

            if op_name not in ops_lut:
                ops_lut[op_name] = {}

            if op_data_shape not in ops_lut[op_name]:
                ops_lut[op_name][op_data_shape] = {}

            ops_lut[op_name][op_data_shape][op_dtype] = op_latency

        self.lut_perf = [{} for i in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    layer_in_shape = self.layer_in_shapes[layer_id]
                    input_w = layer_in_shape[1]
                    input_h = layer_in_shape[2]
                    c_in = layer_config[0]
                    c_out = layer_config[1]
                    stride = layer_config[2]
                    op_data_shape = (input_w, input_h, c_in, c_out, stride)

                    if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
                        self.lut_perf[layer_id][op_name] = \
                            ops_lut[op_name][op_data_shape][DATA_TYPE]
                layer_id += 1
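Using the classes above, the hardware-aware penalty can be probed in isolation; a quick sketch, where the FLOPs value is made up for illustration:

import torch

config = NASConfig(perf_metric="flops", mode="add", alpha=0.25, beta=0.6)   # model_dir left as None, so no LUT files are touched
reg = RegularizerLoss(config)
perf_cost = torch.tensor([3.2e8])       # e.g. accumulated weighted FLOPs of the sampled ops
print(reg(perf_cost))                   # alpha * (perf_cost / batch_size) ** beta in "add" mode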
nni/algorithms/nas/pytorch/pdarts/mutator.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy

import numpy as np
import torch
from torch import nn

from nni.algorithms.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):
    """
    It works with PdartsTrainer to calculate ops weights,
    and drop weights in different PDARTS epochs.
    """

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        if switches is None:
            self.switches = {}
        else:
            self.switches = switches

        super(PdartsMutator, self).__init__(model)

        # this loop goes through mutables with different keys,
        # it's mainly to update the length of choices.
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):

                switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
                choices = self.choices[mutable.key]

                operations_count = np.sum(switches)
                # +1 and -1 are caused by the zero operation in the darts network:
                # the zero operation is not in the choices list in the network, but its weight is,
                # so it needs one more weight and switch for zero.
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))

                self.switches[mutable.key] = switches

        # update LayerChoice instances in the model,
        # physically removing the dropped choice operations.
        for module in self.model.modules():
            if isinstance(module, LayerChoice):
                switches = self.switches.get(module.key)
                choices = self.choices[module.key]
                if len(module) > len(choices):
                    # from last to first, so that it won't affect previous indexes after one is removed.
                    for index in range(len(switches) - 1, -1, -1):
                        if switches[index] == False:
                            del module[index]
                assert len(module) <= len(choices), "Failed to remove dropped choices."

    def export(self):
        # Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
        results = super().sample_final()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # As some operations are dropped physically,
                # it needs to fill back False to track dropped operations.
                trained_result = results[mutable.key]
                trained_index = 0
                switches = self.switches[mutable.key]
                result = torch.Tensor(switches).bool()
                for index in range(len(result)):
                    if result[index]:
                        result[index] = trained_result[trained_index]
                        trained_index += 1
                results[mutable.key] = result
        return results

    def drop_paths(self):
        """
        This method is called when a PDARTS epoch is finished.
        It prepares switches for the next epoch.
        Candidate operations with a False switch will be dropped in the next epoch.
        """
        all_switches = copy.deepcopy(self.switches)
        for key in all_switches:
            switches = all_switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
            for idx in drop:
                switches[idxs[idx]] = False
        return all_switches
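The dropping rule in drop_paths is essentially an argsort over the learned op weights of each LayerChoice. A toy illustration of that rule on one edge (the weights are made up, and the real method additionally maps indices through the still-alive switches):

import numpy as np

switches = [True] * 8                                             # 8 candidate ops, all still alive
weights = np.array([0.3, 0.1, 0.5, 0.2, 0.9, 0.4, 0.05, 0.6])    # toy learned op weights
num_to_drop = 3
for idx in np.argsort(weights)[:num_to_drop]:
    switches[idx] = False
print(switches)   # the three lowest-weighted ops are switched off for the next P-DARTS stage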
nni/algorithms/nas/pytorch/pdarts/trainer.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging

from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder

from .mutator import PdartsMutator

logger = logging.getLogger(__name__)


class PdartsTrainer(BaseTrainer):
    """
    This trainer implements the PDARTS algorithm.
    PDARTS is based on the DARTS algorithm, and provides a network growth approach to find a deeper and better network.
    This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
    pdarts_num_layers means how many layers more than the first epoch.
    pdarts_num_to_drop means how many candidate operations should be dropped in each epoch,
    so that the grown network stays at a similar size.
    """

    def __init__(self, model_creator, init_layers, metrics,
                 num_epochs, dataset_train, dataset_valid,
                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, unrolled=False):
        super(PdartsTrainer, self).__init__()
        self.model_creator = model_creator
        self.init_layers = init_layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled
        }
        self.callbacks = callbacks if callbacks is not None else []

    def train(self):
        switches = None
        for epoch in range(self.pdarts_epoch):

            layers = self.init_layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                        callbacks=darts_callbacks, **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)

            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        self.trainer.validate()

    def export(self, file):
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")
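For intuition, with the default ``pdarts_num_layers`` and ``pdarts_num_to_drop`` above, the three P-DARTS stages grow the network while pruning candidate ops. Assuming, say, 5 initial layers (an arbitrary value chosen only for this illustration):

init_layers = 5                      # illustrative; passed by the user in practice
pdarts_num_layers = [0, 6, 12]
pdarts_num_to_drop = [3, 2, 1]
for epoch, (extra, drop) in enumerate(zip(pdarts_num_layers, pdarts_num_to_drop)):
    print(f"stage {epoch}: {init_layers + extra} layers, then drop {drop} candidate ops per choice")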
nni/algorithms/nas/pytorch/proxylessnas/__init__.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer
nni/algorithms/nas/pytorch/proxylessnas/mutator.py
deleted
100644 → 0
View file @
d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
math
import
torch
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
import
numpy
as
np
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
nni.nas.pytorch.mutables
import
LayerChoice
from
.utils
import
detach_variable
class
ArchGradientFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
binary_gates
,
run_func
,
backward_func
):
ctx
.
run_func
=
run_func
ctx
.
backward_func
=
backward_func
detached_x
=
detach_variable
(
x
)
with
torch
.
enable_grad
():
output
=
run_func
(
detached_x
)
ctx
.
save_for_backward
(
detached_x
,
output
)
return
output
.
data
@
staticmethod
def
backward
(
ctx
,
grad_output
):
detached_x
,
output
=
ctx
.
saved_tensors
grad_x
=
torch
.
autograd
.
grad
(
output
,
detached_x
,
grad_output
,
only_inputs
=
True
)
# compute gradients w.r.t. binary_gates
binary_grads
=
ctx
.
backward_func
(
detached_x
.
data
,
output
.
data
,
grad_output
.
data
)
return
grad_x
[
0
],
binary_grads
,
None
,
None
class
MixedOp
(
nn
.
Module
):
"""
This class is to instantiate and manage info of one LayerChoice.
It includes architecture weights, binary weights, and member functions
operating the weights.
forward_mode:
forward/backward mode for LayerChoice: None, two, full, and full_v2.
For training architecture weights, we use full_v2 by default, and for training
model weights, we use None.
"""
forward_mode
=
None
def
__init__
(
self
,
mutable
):
"""
Parameters
----------
mutable : LayerChoice
A LayerChoice in user model
"""
super
(
MixedOp
,
self
).
__init__
()
self
.
ap_path_alpha
=
nn
.
Parameter
(
torch
.
Tensor
(
len
(
mutable
)))
self
.
ap_path_wb
=
nn
.
Parameter
(
torch
.
Tensor
(
len
(
mutable
)))
self
.
ap_path_alpha
.
requires_grad
=
False
self
.
ap_path_wb
.
requires_grad
=
False
self
.
active_index
=
[
0
]
self
.
inactive_index
=
None
self
.
log_prob
=
None
self
.
current_prob_over_ops
=
None
self
.
n_choices
=
len
(
mutable
)
def
get_ap_path_alpha
(
self
):
return
self
.
ap_path_alpha
def
to_requires_grad
(
self
):
self
.
ap_path_alpha
.
requires_grad
=
True
self
.
ap_path_wb
.
requires_grad
=
True
def
to_disable_grad
(
self
):
self
.
ap_path_alpha
.
requires_grad
=
False
self
.
ap_path_wb
.
requires_grad
=
False
def
forward
(
self
,
mutable
,
x
):
"""
Define forward of LayerChoice. For 'full_v2', backward is also defined.
The 'two' mode is explained in section 3.2.1 in the paper.
The 'full_v2' mode is explained in Appendix D in the paper.
Parameters
----------
mutable : LayerChoice
this layer's mutable
x : tensor
inputs of this layer, only support one input
Returns
-------
output: tensor
output of this layer
"""
if
MixedOp
.
forward_mode
==
'full'
or
MixedOp
.
forward_mode
==
'two'
:
output
=
0
for
_i
in
self
.
active_index
:
oi
=
self
.
candidate_ops
[
_i
](
x
)
output
=
output
+
self
.
ap_path_wb
[
_i
]
*
oi
for
_i
in
self
.
inactive_index
:
oi
=
self
.
candidate_ops
[
_i
](
x
)
output
=
output
+
self
.
ap_path_wb
[
_i
]
*
oi
.
detach
()
elif
MixedOp
.
forward_mode
==
'full_v2'
:
def
run_function
(
key
,
candidate_ops
,
active_id
):
def
forward
(
_x
):
return
candidate_ops
[
active_id
](
_x
)
return
forward
def
backward_function
(
key
,
candidate_ops
,
active_id
,
binary_gates
):
def
backward
(
_x
,
_output
,
grad_output
):
binary_grads
=
torch
.
zeros_like
(
binary_gates
.
data
)
with
torch
.
no_grad
():
for
k
in
range
(
len
(
candidate_ops
)):
if
k
!=
active_id
:
out_k
=
candidate_ops
[
k
](
_x
.
data
)
else
:
out_k
=
_output
.
data
grad_k
=
torch
.
sum
(
out_k
*
grad_output
)
binary_grads
[
k
]
=
grad_k
return
binary_grads
return
backward
output
=
ArchGradientFunction
.
apply
(
x
,
self
.
ap_path_wb
,
run_function
(
mutable
.
key
,
list
(
mutable
),
self
.
active_index
[
0
]),
backward_function
(
mutable
.
key
,
list
(
mutable
),
self
.
active_index
[
0
],
self
.
ap_path_wb
))
else
:
output
=
self
.
active_op
(
mutable
)(
x
)
return
output
@
property
def
probs_over_ops
(
self
):
"""
Apply softmax on alpha to generate probability distribution
Returns
-------
pytorch tensor
probability distribution
"""
probs
=
F
.
softmax
(
self
.
ap_path_alpha
,
dim
=
0
)
# softmax to probability
return
probs
@
property
def
chosen_index
(
self
):
"""
choose the op with max prob
Returns
-------
int
index of the chosen one
numpy.float32
prob of the chosen one
"""
probs
=
self
.
probs_over_ops
.
data
.
cpu
().
numpy
()
index
=
int
(
np
.
argmax
(
probs
))
return
index
,
probs
[
index
]
def
active_op
(
self
,
mutable
):
"""
assume only one path is active
Returns
-------
PyTorch module
the chosen operation
"""
return
mutable
[
self
.
active_index
[
0
]]
@
property
def
active_op_index
(
self
):
"""
return active op's index, the active op is sampled
Returns
-------
int
index of the active op
"""
return
self
.
active_index
[
0
]
def
set_chosen_op_active
(
self
):
"""
set chosen index, active and inactive indexes
"""
chosen_idx
,
_
=
self
.
chosen_index
self
.
active_index
=
[
chosen_idx
]
self
.
inactive_index
=
[
_i
for
_i
in
range
(
0
,
chosen_idx
)]
+
\
[
_i
for
_i
in
range
(
chosen_idx
+
1
,
self
.
n_choices
)]
def
binarize
(
self
,
mutable
):
"""
Sample based on alpha, and set binary weights accordingly.
ap_path_wb is set in this function, which is called binarize.
Parameters
----------
mutable : LayerChoice
this layer's mutable
"""
self
.
log_prob
=
None
# reset binary gates
self
.
ap_path_wb
.
data
.
zero_
()
probs
=
self
.
probs_over_ops
if
MixedOp
.
forward_mode
==
'two'
:
# sample two ops according to probs
sample_op
=
torch
.
multinomial
(
probs
.
data
,
2
,
replacement
=
False
)
probs_slice
=
F
.
softmax
(
torch
.
stack
([
self
.
ap_path_alpha
[
idx
]
for
idx
in
sample_op
]),
dim
=
0
)
self
.
current_prob_over_ops
=
torch
.
zeros_like
(
probs
)
for
i
,
idx
in
enumerate
(
sample_op
):
self
.
current_prob_over_ops
[
idx
]
=
probs_slice
[
i
]
# choose one to be active and the other to be inactive according to probs_slice
c
=
torch
.
multinomial
(
probs_slice
.
data
,
1
)[
0
]
# 0 or 1
active_op
=
sample_op
[
c
].
item
()
inactive_op
=
sample_op
[
1
-
c
].
item
()
self
.
active_index
=
[
active_op
]
self
.
inactive_index
=
[
inactive_op
]
# set binary gate
self
.
ap_path_wb
.
data
[
active_op
]
=
1.0
else
:
sample
=
torch
.
multinomial
(
probs
,
1
)[
0
].
item
()
self
.
active_index
=
[
sample
]
self
.
inactive_index
=
[
_i
for
_i
in
range
(
0
,
sample
)]
+
\
[
_i
for
_i
in
range
(
sample
+
1
,
len
(
mutable
))]
self
.
log_prob
=
torch
.
log
(
probs
[
sample
])
self
.
current_prob_over_ops
=
probs
self
.
ap_path_wb
.
data
[
sample
]
=
1.0
# avoid over-regularization
for
choice
in
mutable
:
for
_
,
param
in
choice
.
named_parameters
():
param
.
grad
=
None
@
staticmethod
def
delta_ij
(
i
,
j
):
if
i
==
j
:
return
1
else
:
return
0
def
set_arch_param_grad
(
self
,
mutable
):
"""
Calculate alpha gradient for this LayerChoice.
It is calculated using gradient of binary gate, probs of ops.
"""
binary_grads
=
self
.
ap_path_wb
.
grad
.
data
if
self
.
active_op
(
mutable
).
is_zero_layer
():
self
.
ap_path_alpha
.
grad
=
None
return
if
self
.
ap_path_alpha
.
grad
is
None
:
self
.
ap_path_alpha
.
grad
=
torch
.
zeros_like
(
self
.
ap_path_alpha
.
data
)
if
MixedOp
.
forward_mode
==
'two'
:
involved_idx
=
self
.
active_index
+
self
.
inactive_index
probs_slice
=
F
.
softmax
(
torch
.
stack
([
self
.
ap_path_alpha
[
idx
]
for
idx
in
involved_idx
]),
dim
=
0
).
data
for
i
in
range
(
2
):
for
j
in
range
(
2
):
origin_i
=
involved_idx
[
i
]
origin_j
=
involved_idx
[
j
]
self
.
ap_path_alpha
.
grad
.
data
[
origin_i
]
+=
\
binary_grads
[
origin_j
]
*
probs_slice
[
j
]
*
(
MixedOp
.
delta_ij
(
i
,
j
)
-
probs_slice
[
i
])
for
_i
,
idx
in
enumerate
(
self
.
active_index
):
self
.
active_index
[
_i
]
=
(
idx
,
self
.
ap_path_alpha
.
data
[
idx
].
item
())
for
_i
,
idx
in
enumerate
(
self
.
inactive_index
):
self
.
inactive_index
[
_i
]
=
(
idx
,
self
.
ap_path_alpha
.
data
[
idx
].
item
())
else
:
probs
=
self
.
probs_over_ops
.
data
for
i
in
range
(
self
.
n_choices
):
for
j
in
range
(
self
.
n_choices
):
self
.
ap_path_alpha
.
grad
.
data
[
i
]
+=
binary_grads
[
j
]
*
probs
[
j
]
*
(
MixedOp
.
delta_ij
(
i
,
j
)
-
probs
[
i
])
return
def
rescale_updated_arch_param
(
self
):
"""
rescale architecture weights for the 'two' mode.
"""
if
not
isinstance
(
self
.
active_index
[
0
],
tuple
):
assert
self
.
active_op
.
is_zero_layer
()
return
involved_idx
=
[
idx
for
idx
,
_
in
(
self
.
active_index
+
self
.
inactive_index
)]
old_alphas
=
[
alpha
for
_
,
alpha
in
(
self
.
active_index
+
self
.
inactive_index
)]
new_alphas
=
[
self
.
ap_path_alpha
.
data
[
idx
]
for
idx
in
involved_idx
]
offset
=
math
.
log
(
sum
([
math
.
exp
(
alpha
)
for
alpha
in
new_alphas
])
/
sum
([
math
.
exp
(
alpha
)
for
alpha
in
old_alphas
])
)
for
idx
in
involved_idx
:
self
.
ap_path_alpha
.
data
[
idx
]
-=
offset
class
ProxylessNasMutator
(
BaseMutator
):
"""
This mutator initializes and operates all the LayerChoices of the input model.
It is for the corresponding trainer to control the training process of LayerChoices,
coordinating with whole training process.
"""
def
__init__
(
self
,
model
):
"""
Init a MixedOp instance for each mutable i.e., LayerChoice.
And register the instantiated MixedOp in corresponding LayerChoice.
If does not register it in LayerChoice, DataParallel does not work then,
because architecture weights are not included in the DataParallel model.
When MixedOPs are registered, we use ```requires_grad``` to control
whether calculate gradients of architecture weights.
Parameters
----------
model : pytorch model
The model that users want to tune, it includes search space defined with nni nas apis
"""
super
(
ProxylessNasMutator
,
self
).
__init__
(
model
)
self
.
_unused_modules
=
None
        self.mutable_list = []
        for mutable in self.undedup_mutables:
            self.mutable_list.append(mutable)
            mutable.registered_module = MixedOp(mutable)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward. This function defines the forward
        logic of the input mutable, so the mutable is only an interface; its real
        implementation is defined in the mutator.

        Parameters
        ----------
        mutable : LayerChoice
            the mutable whose forward logic is being defined
        args : list of torch.Tensor
            inputs of this mutable
        kwargs : dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        idx = mutable.registered_module.active_op_index
        return mutable.registered_module(mutable, *args, **kwargs), idx

    def reset_binary_gates(self):
        """
        For each LayerChoice, binarize the binary weights based on alpha
        so that only one op is activated.
        It traverses all the mutables in the model to do this.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.binarize(mutable)

    def set_chosen_op_active(self):
        """
        For each LayerChoice, set the op with the highest alpha as the chosen op.
        Usually used for validation.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_chosen_op_active()

    def num_arch_params(self):
        """
        The number of mutables, i.e., LayerChoice.

        Returns
        -------
        int
            the number of LayerChoice in the user model
        """
        return len(self.mutable_list)

    def set_arch_param_grad(self):
        """
        For each LayerChoice, calculate gradients for the architecture weights, i.e., alpha.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_arch_param_grad(mutable)

    def get_architecture_parameters(self):
        """
        Get all the architecture parameters.

        Yields
        ------
        PyTorch Parameter
            ap_path_alpha of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_ap_path_alpha()

    def change_forward_mode(self, mode):
        """
        Update the forward mode of the MixedOps, as training architecture weights and
        training model weights use different forward modes.
        """
        MixedOp.forward_mode = mode

    def get_forward_mode(self):
        """
        Get the forward mode of MixedOp.

        Returns
        -------
        string
            the current forward mode of MixedOp
        """
        return MixedOp.forward_mode

    def rescale_updated_arch_param(self):
        """
        Rescale architecture weights in 'two' mode.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.rescale_updated_arch_param()

    def unused_modules_off(self):
        """
        Remove unused modules for each mutable.
        The removed modules are kept in ``self._unused_modules`` so that they can be resumed later.
        """
        self._unused_modules = []
        for mutable in self.undedup_mutables:
            mixed_op = mutable.registered_module
            unused = {}
            if self.get_forward_mode() in ['full', 'two', 'full_v2']:
                involved_index = mixed_op.active_index + mixed_op.inactive_index
            else:
                involved_index = mixed_op.active_index
            for i in range(mixed_op.n_choices):
                if i not in involved_index:
                    unused[i] = mutable[i]
                    mutable[i] = None
            self._unused_modules.append(unused)

    def unused_modules_back(self):
        """
        Resume the removed modules.
        """
        if self._unused_modules is None:
            return
        for m, unused in zip(self.mutable_list, self._unused_modules):
            for i in unused:
                m[i] = unused[i]
        self._unused_modules = None

    def arch_requires_grad(self):
        """
        Make architecture weights require gradient.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradient of architecture weights, i.e., do not
        calculate gradients for them.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice
        """
        result = dict()
        for mutable in self.undedup_mutables:
            assert isinstance(mutable, LayerChoice)
            index, _ = mutable.registered_module.chosen_index
            # pylint: disable=not-callable
            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
        return result
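The deleted trainer.py below is shown collapsed, so the following is only a minimal sketch of how a training loop might drive the mutator hooks above. The helper names weight_step and arch_step, the optimizers, and the data batches are hypothetical; only the mutator calls come from the code above, and the exact ordering in the real trainer may differ.

import torch.nn.functional as F

def weight_step(model, mutator, w_optimizer, x, y):
    # Sample one active op per LayerChoice via the binary gates, and drop the
    # inactive branches so they consume no memory during this step.
    mutator.reset_binary_gates()
    mutator.unused_modules_off()
    loss = F.cross_entropy(model(x), y)
    w_optimizer.zero_grad()
    loss.backward()
    w_optimizer.step()
    mutator.unused_modules_back()

def arch_step(model, mutator, alpha_optimizer, x, y):
    # Architecture updates use a different forward mode ('two' here) and need
    # gradients on the architecture weights.
    mutator.change_forward_mode('two')
    mutator.arch_requires_grad()
    mutator.reset_binary_gates()
    mutator.unused_modules_off()
    loss = F.cross_entropy(model(x), y)
    alpha_optimizer.zero_grad()
    loss.backward()
    # Translate gradients on the binary gates into gradients on alpha, update alpha,
    # then rescale as required by the 'two' mode.
    mutator.set_arch_param_grad()
    alpha_optimizer.step()
    mutator.rescale_updated_arch_param()
    mutator.unused_modules_back()
    mutator.arch_disable_grad()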
nni/algorithms/nas/pytorch/proxylessnas/trainer.py
deleted 100644 → 0
View file @ d6dcb483
This diff is collapsed.
nni/algorithms/nas/pytorch/proxylessnas/utils.py
deleted 100644 → 0
View file @ d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


def detach_variable(inputs):
    """
    Detach tensors from the current autograd graph while keeping their
    ``requires_grad`` flags.

    Parameters
    ----------
    inputs : torch.Tensor or tuple of torch.Tensor
        the tensor(s) to detach
    """
    if isinstance(inputs, tuple):
        return tuple([detach_variable(x) for x in inputs])
    else:
        x = inputs.detach()
        x.requires_grad = inputs.requires_grad
        return x
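A quick illustration (the tensor names below are made up): detach_variable returns tensors that are cut off from the existing autograd graph, while preserving each input's requires_grad flag so a fresh graph can be built from them.

import torch

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 3)
da, db = detach_variable((a, b))
print(da.requires_grad, da.grad_fn)  # True None -- a detached leaf that is still trainable
print(db.requires_grad)              # False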
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """
    Cross-entropy loss with label smoothing.

    Parameters
    ----------
    pred : torch.Tensor
        predicted logits
    target : torch.Tensor
        ground-truth labels
    label_smoothing : float
        the degree of label smoothing

    Returns
    -------
    torch.Tensor
        cross entropy
    """
    logsoftmax = nn.LogSoftmax(dim=1)
    n_classes = pred.size(1)
    # convert to one-hot
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros_like(pred)
    soft_target.scatter_(1, target, 1)
    # label smoothing
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1))
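A small sanity check, assuming the function above is in scope: the smoothed target gives every class an extra label_smoothing / n_classes while the true class keeps 1 - label_smoothing, so with label_smoothing=0 the target collapses to a plain one-hot vector and the result should match torch.nn.functional.cross_entropy.

import torch
import torch.nn.functional as F

pred = torch.randn(4, 10)              # logits for a batch of 4 over 10 classes
target = torch.tensor([1, 0, 3, 9])

plain = cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.0)
smoothed = cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1)
print(torch.allclose(plain, F.cross_entropy(pred, target)))  # True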
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output : torch.Tensor
        predicted logits
    target : torch.Tensor
        ground-truth labels
    topk : tuple of int
        the values of k to report, e.g., (1, 5) for top-1 and top-5

    Returns
    -------
    list
        accuracy (in percent) for each requested k
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
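A tiny worked example (the numbers are illustrative only): with the logits below, the top-1 prediction is correct for two of the three samples, and the true class always falls within the top two.

import torch

output = torch.tensor([[0.1, 0.9, 0.0],   # argmax 1, true class 1 -> correct
                       [0.8, 0.1, 0.1],   # argmax 0, true class 0 -> correct
                       [0.2, 0.5, 0.3]])  # argmax 1, true class 2 -> wrong at top-1
target = torch.tensor([1, 0, 2])
top1, top2 = accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())           # ~66.67 100.0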
nni/algorithms/nas/pytorch/random/mutator.py
deleted 100644 → 0
View file @ d6dcb483
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class RandomMutator(Mutator):
    """
    Random mutator that samples a random candidate from the search space each time ``reset()`` is called.
    It uses the random functions in PyTorch, so users can set the PyTorch seed to ensure deterministic behavior.
    """

    def sample_search(self):
        """
        Sample a random candidate.
        """
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                gen_index = torch.randint(high=len(mutable), size=(1, ))
                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
                else:
                    perm = torch.randperm(mutable.n_candidates)
                    mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
                    result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
        return result

    def sample_final(self):
        """
        Same as :meth:`sample_search`.
        """
        return self.sample_search()
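A minimal usage sketch for RandomMutator under the legacy one-shot NAS API shown above. The toy Net below is invented for illustration; the LayerChoice key is auto-generated unless one is passed explicitly, and reset() both samples a candidate and applies it to subsequent forward passes.

import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)])

    def forward(self, x):
        return self.op(x)

model = Net()
mutator = RandomMutator(model)
mutator.reset()                    # sample and apply a random candidate
out = model(torch.randn(1, 3, 32, 32))
mask = mutator.sample_final()      # dict of boolean masks keyed by mutable key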