OpenDAS / nni · Commit 1cada380

Extract base mutator/trainer and support ENAS micro search space (#1739)

Authored Nov 18, 2019 by Yuge Zhang; committed by QuanluZhang, Nov 18, 2019.
Parent: 3ddab980
Changes: 26

Showing 6 changed files with 187 additions and 126 deletions:
  src/sdk/pynni/nni/nas/pytorch/fixed.py            +58   -0
  src/sdk/pynni/nni/nas/pytorch/mutables.py         +37   -21
  src/sdk/pynni/nni/nas/pytorch/mutator.py          +30   -85
  src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py    +3   -3
  src/sdk/pynni/nni/nas/pytorch/trainer.py           +59   -6
  src/sdk/pynni/nni/nas/utils.py                      +0   -11
src/sdk/pynni/nni/nas/pytorch/fixed.py  (new file, 0 → 100644)

import json

import torch

from nni.nas.pytorch.mutator import Mutator


class FixedArchitecture(Mutator):

    def __init__(self, model, fixed_arc, strict=True):
        """
        Initialize a fixed architecture mutator.

        Parameters
        ----------
        model: nn.Module
            A mutable network.
        fixed_arc: str or dict
            Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
        strict: bool
            Force everything that appears in `fixed_arc` to be used at least once.
        """
        super().__init__(model)
        if isinstance(fixed_arc, str):
            with open(fixed_arc, "r") as f:
                fixed_arc = json.load(f)
        self._fixed_arc = fixed_arc
        self._strict = strict

    def _encode_tensor(self, data):
        if isinstance(data, list):
            if all(map(lambda o: isinstance(o, bool), data)):
                return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
            else:
                return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
        if isinstance(data, dict):
            return {k: self._encode_tensor(v) for k, v in data.items()}
        return data

    def before_pass(self):
        self._unused_key = set(self._fixed_arc.keys())

    def after_pass(self):
        if self._strict:
            if self._unused_key:
                raise ValueError("{} are never used by the network. "
                                 "Set strict=False if you want to disable this check.".format(self._unused_key))

    def _check_key(self, key):
        if key not in self._fixed_arc:
            raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key))

    def on_calc_layer_choice_mask(self, mutable):
        self._check_key(mutable.key)
        return self._fixed_arc[mutable.key]

    def on_calc_input_choice_mask(self, mutable, tags):
        self._check_key(mutable.key)
        return self._fixed_arc[mutable.key]
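
The new FixedArchitecture mutator replays a saved architecture instead of sampling one. The following is a minimal usage sketch and is not part of the commit: the toy network, the explicit "conv" key and the tensor-valued mask dict are illustrative assumptions, and strict=False is passed so the example does not trip the unused-key check in after_pass.

import torch
import torch.nn as nn

from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.fixed import FixedArchitecture


class Net(nn.Module):
    """Toy search space with a single two-way layer choice."""
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                 nn.Conv2d(3, 8, 5, padding=2)], key="conv")

    def forward(self, x):
        return self.conv(x)


model = Net()
# Pin the first candidate; a path to a JSON checkpoint exported earlier would work here as well.
mutator = FixedArchitecture(model, {"conv": torch.tensor([True, False])}, strict=False)

with mutator.forward_pass():
    out = model(torch.randn(1, 3, 32, 32))  # the mask is looked up in the dict, not sampled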

src/sdk/pynni/nni/nas/pytorch/mutables.py

@@ -3,7 +3,7 @@ import torch.nn as nn
 from nni.nas.utils import global_mutable_counting


-class PyTorchMutable(nn.Module):
+class Mutable(nn.Module):
     """
     Mutable is designed to function as a normal layer, with all necessary operators' weights.
     States and weights of architectures should be included in mutator, instead of the layer itself.
@@ -24,15 +24,11 @@ class PyTorchMutable(nn.Module):
             self._key = key
         else:
             self._key = self.__class__.__name__ + str(global_mutable_counting())
-        self.name = self.key
         self.init_hook = self.forward_hook = None

     def __deepcopy__(self, memodict=None):
         raise NotImplementedError("Deep copy doesn't work for mutables.")

-    def __enter__(self):
-        self._check_built()
-        return super().__enter__()
-
     def __call__(self, *args, **kwargs):
         self._check_built()
         return super().__call__(*args, **kwargs)
@@ -47,8 +43,16 @@ class PyTorchMutable(nn.Module):
     def key(self):
         return self._key

+    @property
+    def name(self):
+        return self._name if hasattr(self, "_name") else self._key
+
+    @name.setter
+    def name(self, name):
+        self._name = name
+
     def similar(self, other):
-        return self == other
+        return type(self) == type(other)

     def _check_built(self):
         if not hasattr(self, "mutator"):
@@ -56,8 +60,11 @@ class PyTorchMutable(nn.Module):
                 "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__ "
                 "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

+    def __repr__(self):
+        return "{} ({})".format(self.name, self.key)
+

-class MutableScope(PyTorchMutable):
+class MutableScope(Mutable):
     """
     Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
     is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
@@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable):
     def __init__(self, key):
         super().__init__(key=key)

-    def __enter__(self):
-        self.mutator.enter_mutable_scope(self)
+    def build(self):
+        self.mutator.on_init_mutable_scope(self)

-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.mutator.exit_mutable_scope(self)
+    def __call__(self, *args, **kwargs):
+        try:
+            self.mutator.enter_mutable_scope(self)
+            return super().__call__(*args, **kwargs)
+        finally:
+            self.mutator.exit_mutable_scope(self)


-class LayerChoice(PyTorchMutable):
+class LayerChoice(Mutable):
     def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
         super().__init__(key=key)
         self.length = len(op_candidates)
@@ -83,10 +94,10 @@ class LayerChoice(PyTorchMutable):
         self.return_mask = return_mask

     def __len__(self):
-        return self.length
+        return len(self.choices)

     def forward(self, *inputs):
-        out, mask = self.mutator.on_forward(self, *inputs)
+        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
         if self.return_mask:
             return out, mask
         return out
@@ -95,7 +106,7 @@ class LayerChoice(PyTorchMutable):
         return type(self) == type(other) and self.length == other.length


-class InputChoice(PyTorchMutable):
+class InputChoice(Mutable):
     def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
         super().__init__(key=key)
         assert n_candidates > 0, "Number of candidates must be greater than 0."
@@ -104,16 +115,21 @@ class InputChoice(PyTorchMutable):
         self.reduction = reduction
         self.return_mask = return_mask

-    def forward(self, optional_inputs, semantic_labels=None):
+    def build(self):
+        self.mutator.on_init_input_choice(self)
+
+    def forward(self, optional_inputs, tags=None):
         assert len(optional_inputs) == self.n_candidates, \
             "Length of the input list must be equal to number of candidates."
-        if semantic_labels is None:
-            semantic_labels = ["default_label"] * self.n_candidates
-        out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
+        if tags is None:
+            tags = [""] * self.n_candidates
+        else:
+            assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
+        out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
         if self.return_mask:
             return out, mask
         return out

     def similar(self, other):
         return type(self) == type(other) and \
-               self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
+            self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
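
After the rename, user code keeps subclassing plain nn.Module and declares its mutables in __init__ so the mutator can discover and register them, then passes candidate tensors (optionally with tags) to InputChoice in forward. A short sketch of that intended usage; it is not part of the diff, and the operator candidates and keys below are illustrative:

import torch.nn as nn

from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class Cell(nn.Module):
    def __init__(self):
        super().__init__()
        # declared in __init__ so the mutator can discover and register them
        self.op = LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)], key="op")
        self.skip = InputChoice(n_candidates=2, n_selected=1, key="skip")

    def forward(self, x, prev):
        out = self.op(x)
        # tags name the candidate inputs; their number must match n_candidates
        return self.skip([out, prev], tags=["op_out", "prev"])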

src/sdk/pynni/nni/nas/pytorch/mutator.py

 import logging
 from contextlib import contextmanager

 import torch
 import torch.nn as nn

-from nni.nas.pytorch.mutables import PyTorchMutable
-from nni.nas.utils import to_snake_case
+from nni.nas.pytorch.base_mutator import BaseMutator

 logger = logging.getLogger(__name__)


-class PyTorchMutator(nn.Module):
+class Mutator(BaseMutator, nn.Module):

-    def __init__(self, model):
-        super().__init__()
-        self.before_build(model)
-        self.parse_search_space(model)
-        self.after_build(model)
-
-    def before_build(self, model):
-        pass
-
-    def after_build(self, model):
-        pass
-
-    def named_mutables(self, model):
-        # if distinct is true, the method will filter out those with duplicated keys
-        key2module = dict()
-        for name, module in model.named_modules():
-            if isinstance(module, PyTorchMutable):
-                distinct = False
-                if module.key in key2module:
-                    assert key2module[module.key].similar(module), \
-                        "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
-                else:
-                    distinct = True
-                    key2module[module.key] = module
-                yield name, module, distinct
-
-    def __setattr__(self, key, value):
-        if key in ["model", "net", "network"]:
-            logger.warning("Think twice if you are including the network into mutator.")
-        return super().__setattr__(key, value)
-
-    def parse_search_space(self, model):
-        for name, mutable, distinct in self.named_mutables(model):
-            mutable.name = name
-            mutable.set_mutator(self)
-            if not distinct:
-                continue
-            init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__))
-            if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)):
-                getattr(self, init_method_name)(mutable)
-            else:
-                # fallback to general init
-                self.on_init_general(mutable)
-
-    def on_init_general(self, mutable):
-        pass
-
     def export(self):
         if self._in_forward_pass:
             raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
         if not self._cache:
             raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
                                "You might also want to check if there are no valid mutables in your model.")
         return self._cache

     @contextmanager
     def forward_pass(self):
         self._in_forward_pass = True
         self._cache = dict()
         self.before_pass()
         try:
             yield self
         finally:
             self.after_pass()
             self._in_forward_pass = False

     def before_pass(self):
-        self._in_forward_pass = True
-        self._cache = dict()
-
-    def after_pass(self):
-        self._in_forward_pass = False
-
-    def enter_mutable_scope(self, mutable_scope):
         pass

-    def exit_mutable_scope(self, mutable_scope):
+    def after_pass(self):
         pass

     def forward(self, *inputs):
         raise NotImplementedError("Mutator is not forward-able")

-    def on_forward(self, mutable, *inputs):
-        """Callback on forwarding a mutable"""
-        forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
-        if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
-            return getattr(self, forward_method_name)(mutable, *inputs)
-        else:
-            # fallback to general forward
-            return self.on_forward_general(mutable, *inputs)
-
-    def on_forward_general(self, mutable, *inputs):
-        raise NotImplementedError("Forward has to be implemented")
+    def _check_in_forward_pass(self):
+        if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
+            raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
+                             "super().before_pass() and after_pass() in your override method?")

     def on_forward_layer_choice(self, mutable, *inputs):
         """
         Callback of layer choice forward. Override if you are an advanced user.
         On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
-        (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy speicified
+        (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
         in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.

         Parameters
@@ -111,33 +52,38 @@ class PyTorchMutator(nn.Module):
         Returns
         -------
-        torch.Tensor
+        tuple of torch.Tensor and
+        torch.Tensor
         """
+        self._check_in_forward_pass()
         def _map_fn(op, *inputs):
             return op(*inputs)

         mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
         out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

-    def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
+    def on_forward_input_choice(self, mutable, tensor_list, tags):
         """
         Callback of input choice forward. Override if you are an advanced user.
-        On default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
+        On default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
         to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
-        the list of all tensor outputs with the policy speicified in `mutable.reduction`. It will also cache the
+        the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
         mask with corresponding `mutable.key`.

         Parameters
         ----------
         mutable: InputChoice
-        inputs: list of torch.Tensor
+        tensor_list: list of torch.Tensor
+        tags: list of string

         Returns
         -------
-        torch.Tensor
+        tuple of torch.Tensor and
+        torch.Tensor
         """
-        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
-        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
+        self._check_in_forward_pass()
+        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags))
+        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

     def on_calc_layer_choice_mask(self, mutable):
@@ -157,7 +103,7 @@ class PyTorchMutator(nn.Module):
         """
         raise NotImplementedError("Layer choice mask calculation must be implemented")

-    def on_calc_input_choice_mask(self, mutable, semantic_labels):
+    def on_calc_input_choice_mask(self, mutable, tags):
         """
         Recommended to override. Calculate a mask tensor for a input choice.
@@ -165,7 +111,7 @@ class PyTorchMutator(nn.Module):
         ----------
         mutable: InputChoice
             Corresponding input choice object.
-        semantic_labels: list of string
+        tags: list of string
             The name of labels of input tensors given by user. Usually it's a
             :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.
@@ -179,7 +125,6 @@ class PyTorchMutator(nn.Module):
     def _select_with_mask(self, map_fn, candidates, mask):
         if "BoolTensor" in mask.type():
-            # print(candidates[0], len(mask))
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
             out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
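
With the reflection-based on_forward dispatch removed, the two mask callbacks are essentially the whole subclassing contract of the refactored Mutator; caching, reduction, and the calls from the mutables are inherited. A minimal sketch, not part of this commit, with a uniform random policy used purely as an illustration:

import torch

from nni.nas.pytorch.mutator import Mutator


class RandomMutator(Mutator):
    """Picks one layer candidate and one input candidate uniformly at random."""

    def on_calc_layer_choice_mask(self, mutable):
        chosen = torch.randint(mutable.length, (1,)).item()
        return torch.tensor([i == chosen for i in range(mutable.length)], dtype=torch.bool)

    def on_calc_input_choice_mask(self, mutable, tags):
        chosen = torch.randint(mutable.n_candidates, (1,)).item()
        return torch.tensor([i == chosen for i in range(mutable.n_candidates)], dtype=torch.bool)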

src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py

@@ -33,13 +33,13 @@ class PdartsTrainer(Trainer):
         for epoch in range(self.pdarts_epoch):
             layers = self.layers + self.pdarts_num_layers[epoch]
-            model, loss, model_optim, lr_scheduler = self.model_creator(layers, n_nodes)
+            model, loss, model_optim, _ = self.model_creator(layers, n_nodes)
             mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

-            self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
-                                        lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters)
+            self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim,
+                                        mutator=mutator, **self.darts_parameters)
             print("start pdrats training %s..." % epoch)
             self.trainer.train()

src/sdk/pynni/nni/nas/pytorch/trainer.py

-from abc import ABC, abstractmethod
+from abc import abstractmethod

 import torch

-class Trainer(ABC):
+from .base_trainer import BaseTrainer
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
+                 batch_size, workers, device, log_frequency, mutator, callbacks):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+        self.model = model
+        self.loss = loss
+        self.metrics = metrics
+        self.optimizer = optimizer
+        self.mutator = mutator
+
+        self.model.to(self.device)
+        self.loss.to(self.device)
+        self.mutator.to(self.device)
+
+        self.num_epochs = num_epochs
+        self.dataset_train = dataset_train
+        self.dataset_valid = dataset_valid
+        self.batch_size = batch_size
+        self.workers = workers
+        self.log_frequency = log_frequency
+        self.callbacks = callbacks if callbacks is not None else []
+        for callback in self.callbacks:
+            callback.build(self.model, self.mutator, self)

-    @abstractmethod
-    def train(self):
-        raise NotImplementedError
+    def train_one_epoch(self, epoch):
+        pass

     @abstractmethod
     def export(self):
         raise NotImplementedError

+    def validate_one_epoch(self, epoch):
+        pass
+
+    def _train(self, validate):
+        for epoch in range(self.num_epochs):
+            for callback in self.callbacks:
+                callback.on_epoch_begin(epoch)
+
+            # training
+            print("Epoch {} Training".format(epoch))
+            self.train_one_epoch(epoch)
+
+            if validate:
+                # validation
+                print("Epoch {} Validating".format(epoch))
+                self.validate_one_epoch(epoch)
+
+            for callback in self.callbacks:
+                callback.on_epoch_end(epoch)
+
+    def train_and_validate(self):
+        self._train(True)
+
+    def train(self):
+        self._train(False)
+
+    def validate(self):
+        self.validate_one_epoch(-1)
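
The extracted base class now owns the epoch loop (_train) and the callback plumbing, so a concrete trainer only fills in the per-epoch hooks and export. A rough sketch of that contract follows, assuming the inherited constructor is used as-is; none of the code below is part of the commit, and the training loop details are illustrative:

import torch

from nni.nas.pytorch.trainer import Trainer


class SketchTrainer(Trainer):
    def train_one_epoch(self, epoch):
        loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.batch_size,
                                             num_workers=self.workers, shuffle=True)
        self.model.train()
        for step, (x, y) in enumerate(loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            # one forward pass per batch, so the mutator caches one set of masks per batch
            with self.mutator.forward_pass():
                loss = self.loss(self.model(x), y)
            loss.backward()
            self.optimizer.step()
            if self.log_frequency and step % self.log_frequency == 0:
                print("Epoch {} step {}: loss {:.4f}".format(epoch, step, loss.item()))

    def validate_one_epoch(self, epoch):
        pass  # metric evaluation omitted in this sketch

    def export(self):
        return self.mutator.export()

Instantiation goes through the inherited constructor (model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, batch_size, workers, device, log_frequency, mutator, callbacks); passing device=None selects CUDA when available.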

src/sdk/pynni/nni/nas/utils.py

-import re
 from collections import OrderedDict

-import torch
-
 _counter = 0

@@ -12,14 +9,6 @@ def global_mutable_counting():
     return _counter


-def to_snake_case(camel_case):
-    return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower()
-
-
-def auto_device():
-    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
 class AverageMeterGroup(object):
     def __init__(self):
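
For reference, the removed helper only existed to turn class names into the on_init_*/on_forward_* suffixes used by the old reflection-based dispatch in mutator.py. A standalone illustration of what it computed (not part of the repository):

import re


def to_snake_case(camel_case):
    return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower()


print(to_snake_case("LayerChoice"))  # layer_choice -> old hook name on_forward_layer_choice
print(to_snake_case("InputChoice"))  # input_choice -> old hook name on_forward_input_choice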