Commit 13d03757
authored Jan 16, 2020 by Houwen Peng, committed Jan 16, 2020 by Yuge Zhang

integrate c-darts nas algorithm (#1955)

parent a9711e24

Changes: 22 files in the commit; this page shows 2 changed files with 351 additions and 0 deletions (+351 −0).
src/sdk/pynni/nni/nas/pytorch/cdarts/trainer.py   +275 −0
src/sdk/pynni/nni/nas/pytorch/cdarts/utils.py     +76 −0
src/sdk/pynni/nni/nas/pytorch/cdarts/trainer.py (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import apex  # pylint: disable=import-error
from apex.parallel import DistributedDataParallel  # pylint: disable=import-error
from nni.nas.pytorch.cdarts import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup  # pylint: disable=wrong-import-order

from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics

PHASE_SMALL = "small"
PHASE_LARGE = "large"


class InteractiveKLLoss(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        # self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
        self.kl_loss = nn.KLDivLoss()

    def forward(self, student, teacher):
        return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
                            F.softmax(teacher / self.temperature, dim=1))


class CdartsTrainer(object):
    def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                 regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                 epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
                 log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                 w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                 nasnet_lr=0.2, local_rank=0, share_module=True):
        """
        Initialize a CdartsTrainer.

        Parameters
        ----------
        model_small : nn.Module
            PyTorch model to be trained. This is the search network of CDARTS.
        model_large : nn.Module
            PyTorch model to be trained. This is the evaluation network of CDARTS.
        criterion : callable
            Receives logits and ground truth label, returns a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
        loaders : list of torch.utils.data.DataLoader
            List of train data and valid data loaders, for training weights and architecture weights respectively.
        samplers : list of torch.utils.data.Sampler
            List of train data and valid data samplers. These can be PyTorch standard samplers if not distributed.
            In distributed mode, the sampler needs to have a ``set_epoch`` method. Refer to data utils in the CDARTS example for details.
        logger : logging.Logger
            The logger for logging. Will use the nni logger by default (if logger is ``None``).
        regular_coeff : float
            Initial coefficient of the regularization loss produced by the mutator.
        regular_ratio : float
            Fraction of the post-warmup epochs over which the regularization coefficient decays linearly to zero.
        warmup_epochs : int
            Number of epochs used to warm up the search network.
        fix_head : bool
            ``True`` if the parameters of the auxiliary heads are fixed (not trained); otherwise they are trainable.
        epochs : int
            Number of epochs planned for training.
        steps_per_epoch : int
            Number of steps in one epoch. If ``None``, it is set to the minimum of the train and valid loader lengths.
        loss_alpha : float
            Coefficient balancing the classification loss and the interactive (soft-target) loss.
        loss_T : float
            Temperature of the interactive loss (as in knowledge distillation).
        distributed : bool
            ``True`` if using distributed training, else non-distributed training.
        log_frequency : int
            Number of steps between two logging messages.
        grad_clip : float
            Gradient clipping threshold for the network weights.
        interactive_type : string
            ``kl`` or ``smoothl1``.
        output_path : string
            Directory where exported architectures and genotypes are stored.
        w_lr : float
            Learning rate of the search network parameters.
        w_momentum : float
            Momentum of the SGD optimizers for the search and the evaluation network.
        w_weight_decay : float
            Weight decay of the search and the evaluation network parameters.
        alpha_lr : float
            Learning rate of the architecture parameters.
        alpha_weight_decay : float
            Weight decay of the architecture parameters.
        nasnet_lr : float
            Learning rate of the evaluation network parameters.
        local_rank : int
            Rank of the current process in distributed training (process 0 writes logs and checkpoints).
        share_module : bool
            ``True`` if the search and evaluation networks share the stem and auxiliary heads; otherwise these modules are not shared.
        """
        if logger is None:
            logger = logging.getLogger(__name__)
        train_loader, valid_loader = loaders
        train_sampler, valid_sampler = samplers
        self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
        self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)

        self.regular_coeff = regular_coeff
        self.regular_ratio = regular_ratio
        self.warmup_epochs = warmup_epochs
        self.fix_head = fix_head
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        if self.steps_per_epoch is None:
            self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
        self.loss_alpha = loss_alpha
        self.grad_clip = grad_clip
        if interactive_type == "kl":
            self.interactive_loss = InteractiveKLLoss(loss_T)
        elif interactive_type == "smoothl1":
            self.interactive_loss = nn.SmoothL1Loss()
        self.loss_T = loss_T
        self.distributed = distributed
        self.log_frequency = log_frequency
        self.main_proc = not distributed or local_rank == 0

        self.logger = logger
        self.checkpoint_dir = output_path
        if self.main_proc:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        if distributed:
            torch.distributed.barrier()

        self.model_small = model_small
        self.model_large = model_large
        if self.fix_head:
            for param in self.model_small.aux_head.parameters():
                param.requires_grad = False
            for param in self.model_large.aux_head.parameters():
                param.requires_grad = False

        self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
        self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
        self.criterion = criterion

        self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
                                                betas=(0.5, 0.999), weight_decay=alpha_weight_decay)

        if distributed:
            apex.parallel.convert_syncbn_model(self.model_small)
            apex.parallel.convert_syncbn_model(self.model_large)
            self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
            self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
            self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
            if share_module:
                self.model_small.callback_queued = True
                self.model_large.callback_queued = True
            # mutator_large never gets optimized, so it does not need to be parallelized

    def _warmup(self, phase, epoch):
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            x, y = next(self.train_loader)
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            logits_main, _ = model(x)
            loss = self.criterion(logits_main, y)
            loss.backward()

            self._clip_grad_norm(model)
            optimizer.step()
            prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)

    def _clip_grad_norm(self, model):
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        with torch.no_grad():
            for param in parameters:
                for i, p in enumerate(param):
                    if p != p:  # equivalent to `isnan(p)`
                        param[i] = float("-inf")

    def _joint_train(self, epoch):
        self.model_large.train()
        self.model_small.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            trn_x, trn_y = next(self.train_loader)
            val_x, val_y = next(self.valid_loader)
            trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()

            # step 1. optimize architecture
            self.optimizer_alpha.zero_grad()
            self.optimizer_large.zero_grad()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                            ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            loss_regular = self.mutator_small.reset_with_loss()
            if loss_regular:
                loss_regular *= reg_decay
            logits_search, emsemble_logits_search = self.model_small(val_x)
            logits_main, emsemble_logits_main = self.model_large(val_x)
            loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
            loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
            loss = loss_cls + loss_interactive + loss_regular
            loss.backward()
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.optimizer_alpha.step()
            # NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`

            # step 2. optimize op weights
            self.optimizer_small.zero_grad()
            with torch.no_grad():
                # resample architecture since parameters have been changed
                self.mutator_small.reset_with_loss()
            logits_search_train, _ = self.model_small(trn_x)
            loss_weight = self.criterion(logits_search_train, trn_y)
            loss_weight.backward()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()

            metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
                       "loss_regular": loss_regular, "loss_weight": loss_weight}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():
                    # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)

            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))
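As a quick orientation for reviewers, here is a minimal, hypothetical sketch of how this trainer is meant to be driven. The search/evaluation network builders, the datasets, and the re-export of CdartsTrainer from the nni.nas.pytorch.cdarts package are assumptions (they belong to other files of this commit not shown on this page), so treat the snippet as illustrative rather than runnable as-is.

# Illustrative sketch only -- `build_search_network`, `build_eval_network`,
# `train_set` and `valid_set` are hypothetical placeholders supplied by the user.
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from nni.nas.pytorch.cdarts import CdartsTrainer  # assumes the package re-exports the trainer

model_small = build_search_network()   # search network; must expose an `aux_head` submodule
model_large = build_eval_network()     # evaluation network; must expose an `aux_head` submodule
# both forwards must return (logits, ensemble_logits), as consumed by _joint_train above
criterion = nn.CrossEntropyLoss()

train_sampler = RandomSampler(train_set)
valid_sampler = RandomSampler(valid_set)
loaders = [DataLoader(train_set, batch_size=64, sampler=train_sampler),
           DataLoader(valid_set, batch_size=64, sampler=valid_sampler)]

trainer = CdartsTrainer(model_small, model_large, criterion,
                        loaders=loaders,
                        samplers=[train_sampler, valid_sampler],
                        distributed=False,   # single-process sketch; distributed mode also needs apex and a launcher
                        epochs=32, warmup_epochs=2,
                        output_path='./outputs')
trainer.train()                              # warm-up epochs, then alternating joint optimization
trainer.export('final_arch.json', 'final_arch.genotypes')

The call sequence mirrors train() above: warm-up epochs train only the search network, after which each epoch alternates architecture and weight updates and exports the current architecture to output_path.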
src/sdk/pynni/nni/nas/pytorch/cdarts/utils.py (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os

import torch
import torch.distributed as dist


class CyclicIterator:
    def __init__(self, loader, sampler, distributed):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self.distributed = distributed
        self._next_epoch()

    def _next_epoch(self):
        if self.distributed:
            self.sampler.set_epoch(self.epoch)
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self._next_epoch()
            return next(self.iterator)


class TorchTensorEncoder(json.JSONEncoder):
    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            return o.tolist()
        return super().default(o)


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(os.environ["WORLD_SIZE"])
    return rt


def reduce_metrics(metrics, distributed=False):
    if distributed:
        return {k: reduce_tensor(v).item() for k, v in metrics.items()}
    return {k: v.item() for k, v in metrics.items()}
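These helpers are small enough to exercise directly. A minimal single-process sketch of accuracy, reduce_metrics, and TorchTensorEncoder is shown below (it assumes the nni source tree is importable); reduce_tensor and the distributed=True path additionally require an initialized torch.distributed process group and a WORLD_SIZE environment variable.

# Quick single-process check of the helpers above; no torch.distributed setup
# is needed when distributed=False.
import json

import torch

from nni.nas.pytorch.cdarts.utils import accuracy, reduce_metrics, TorchTensorEncoder

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.2, 3.0, 0.5],
                       [1.5, 0.3, 0.2],
                       [0.1, 0.2, 4.0]])
labels = torch.tensor([0, 1, 2, 2])

prec1, prec3 = accuracy(logits, labels, topk=(1, 3))  # precision@1 and @3 as tensors
metrics = reduce_metrics({"prec1": prec1, "prec3": prec3}, distributed=False)
print(metrics)  # {'prec1': 0.75, 'prec3': 1.0}

# TorchTensorEncoder lets tensors pass through json.dump(s) as plain lists.
print(json.dumps({"alpha": torch.tensor([0.1, 0.9])}, cls=TorchTensorEncoder))

With distributed=False, reduce_metrics simply converts each tensor to a Python number, which is the form AverageMeterGroup consumes in the trainer; with distributed=True it first averages each tensor across workers via all_reduce.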