OpenDAS / nni · Commits

Commit d43fbe82, authored Nov 11, 2019 by quzha
Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor
Parents: 0e3906aa, bb797e10

The commit changes 27 files in total; this page shows 7 changed files with 640 additions and 0 deletions.

Changed files shown on this page:

    src/sdk/pynni/nni/nas/pytorch/enas/mutator.py      +126  -0
    src/sdk/pynni/nni/nas/pytorch/enas/trainer.py      +120  -0
    src/sdk/pynni/nni/nas/pytorch/mutables.py          +119  -0
    src/sdk/pynni/nni/nas/pytorch/mutator.py           +203  -0
    src/sdk/pynni/nni/nas/pytorch/trainer.py           +12   -0
    src/sdk/pynni/nni/nas/tensorflow/__init__.py       +0    -0
    src/sdk/pynni/nni/nas/utils.py                      +60   -0

src/sdk/pynni/nni/nas/pytorch/enas/mutator.py (new file, mode 0 → 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutator import PyTorchMutator


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_c, prev_h = hidden
        next_c, next_h = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            inputs = curr_h[-1]
        return next_c, next_h


class EnasMutator(PyTorchMutator):
    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                 anchor_extra_step=False, skip_target=0.4):
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.max_layer_choice = 0
        self.anchor_extra_step = anchor_extra_step
        self.skip_target = skip_target
        super().__init__(model)

    def before_build(self, model):
        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
                                         requires_grad=False)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def after_build(self, model):
        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice)

    def before_pass(self):
        super().before_pass()
        self._anchors_hid = dict()
        self._selected_layers = []
        self._selected_inputs = []
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size), dtype=self._inputs.dtype, device=self._inputs.device)
                   for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size), dtype=self._inputs.dtype, device=self._inputs.device)
                   for _ in range(self.lstm_num_layers)]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        self._c, self._h = self.lstm(self._inputs, (self._c, self._h))

    def _mark_anchor(self, key):
        self._anchors_hid[key] = self._h[-1]

    def on_init_layer_choice(self, mutable):
        if self.max_layer_choice == 0:
            self.max_layer_choice = mutable.length
        assert self.max_layer_choice == mutable.length, \
            "ENAS mutator requires all layer choice have the same number of candidates."

    def on_calc_layer_choice_mask(self, mutable):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += log_prob
        entropy = (log_prob * torch.exp(-log_prob)).detach()
        self.sample_entropy += entropy
        self._inputs = self.embedding(branch_id)
        self._selected_layers.append(branch_id.item())
        return F.one_hot(branch_id).bool().view(-1)

    def on_calc_input_choice_mask(self, mutable, semantic_labels):
        if mutable.n_selected is None:
            query, anchors = [], []
            for label in semantic_labels:
                if label not in self._anchors_hid:
                    self._lstm_next_step()
                    self._mark_anchor(label)  # empty loop, fill not found
                query.append(self.attn_anchor(self._anchors_hid[label]))
                anchors.append(self._anchors_hid[label])
            query = torch.cat(query, 0)
            query = torch.tanh(query + self.attn_query(self._h[-1]))
            query = self.v_attn(query)
            logit = torch.cat([-query, query], 1)
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)
            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self.sample_log_prob += torch.sum(log_prob)
            entropy = (log_prob * torch.exp(-log_prob)).detach()
            self.sample_entropy += torch.sum(entropy)
            self.inputs = torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))
            self._selected_inputs.append(skip)
            return skip.bool()
        else:
            assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS."
            raise NotImplementedError

    def exit_mutable_scope(self, mutable_scope):
        self._mark_anchor(mutable_scope.key)
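
Editor's note: the hidden-state layout that StackedLSTMCell expects can be easy to miss. The minimal sketch below is not part of the commit; it assumes the module is importable from its path above and uses the same settings that EnasMutator's defaults pass in (one layer, size 64, no bias).

import torch
from nni.nas.pytorch.enas.mutator import StackedLSTMCell

lstm = StackedLSTMCell(layers=1, size=64, bias=False)  # matches EnasMutator's defaults
inputs = torch.randn(1, 64)                            # one step of controller input
c = [torch.zeros(1, 64)]                               # one (c, h) pair per stacked layer,
h = [torch.zeros(1, 64)]                               # exactly as before_pass initializes _c and _h
c, h = lstm(inputs, (c, h))                            # returns the updated per-layer state lists
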
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py (new file, mode 0 → 100644)

import torch
import torch.optim as optim

from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup, auto_device
from .mutator import EnasMutator


class EnasTrainer(Trainer):
    def __init__(self, model, loss, metrics, reward_function, optimizer, num_epochs,
                 dataset_train, dataset_valid, lr_scheduler=None, mutator=None,
                 batch_size=64, workers=4, device=None, log_frequency=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, mutator_lr=0.00035):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.reward_function = reward_function
        self.mutator = mutator
        if self.mutator is None:
            self.mutator = EnasMutator(model)
        self.optim = optimizer
        self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.lr_scheduler = lr_scheduler
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.device = auto_device() if device is None else device
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.

        self.model.to(self.device)
        self.loss.to(self.device)
        self.mutator.to(self.device)

        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=batch_size,
                                                       num_workers=workers)

    def train_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        for phase in ["model", "mutator"]:
            if phase == "model":
                self.model.train()
                self.mutator.eval()
            else:
                self.model.eval()
                self.mutator.train()
            loader = self.train_loader if phase == "model" else self.valid_loader
            meters = AverageMeterGroup()

            for step, (x, y) in enumerate(loader):
                x, y = x.to(self.device), y.to(self.device)
                self.optim.zero_grad()
                self.mut_optim.zero_grad()

                with self.mutator.forward_pass():
                    logits = self.model(x)
                metrics = self.metrics(logits, y)

                if phase == "model":
                    loss = self.loss(logits, y)
                    loss.backward()
                    self.optim.step()
                else:
                    reward = self.reward_function(logits, y)
                    if self.entropy_weight is not None:
                        reward += self.entropy_weight * self.mutator.sample_entropy
                    self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                    self.baseline = self.baseline.detach().item()
                    loss = self.mutator.sample_log_prob * (reward - self.baseline)
                    if self.skip_weight:
                        loss += self.skip_weight * self.mutator.sample_skip_penalty
                    loss.backward()
                    self.mut_optim.step()
                    metrics["reward"] = reward

                metrics["loss"] = loss.item()
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(),
                                                               step, len(loader), meters))
                    # print(self.mutator._selected_layers)
                    # print(self.mutator._selected_inputs)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def validate_epoch(self, epoch):
        pass

    def train(self):
        for epoch in range(self.num_epochs):
            # training
            print("Epoch {} Training".format(epoch))
            self.train_epoch(epoch)

            # validation
            print("Epoch {} Validating".format(epoch))
            self.validate_epoch(epoch)

    def export(self):
        pass
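
Editor's note: to make the constructor arguments concrete, here is a hypothetical, self-contained sketch of wiring a toy search space into EnasTrainer. The Net, accuracy, and reward helpers, the shapes, and the hyperparameters are illustrative only and not part of the commit; the sketch assumes the package is importable as laid out in this commit.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.enas.trainer import EnasTrainer

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # one searchable layer: the ENAS controller picks between the two candidates
        self.op = LayerChoice([nn.Linear(16, 16), nn.Sequential(nn.Linear(16, 16), nn.ReLU())])
        self.fc = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(self.op(x))

def accuracy(logits, y):        # `metrics` is expected to return a dict of scalar values
    return {"acc": (logits.argmax(1) == y).float().mean().item()}

def reward(logits, y):          # `reward_function` drives the controller's REINFORCE update
    return (logits.argmax(1) == y).float().mean()

data = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
model = Net()
trainer = EnasTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy, reward_function=reward,
                      optimizer=torch.optim.SGD(model.parameters(), lr=0.05), num_epochs=1,
                      dataset_train=data, dataset_valid=data, batch_size=32, log_frequency=1)
trainer.train()
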
src/sdk/pynni/nni/nas/pytorch/mutables.py (new file, mode 0 → 100644)

import torch.nn as nn

from nni.nas.utils import global_mutable_counting


class PyTorchMutable(nn.Module):
    """
    Mutable is designed to function as a normal layer, with all necessary operators' weights.
    States and weights of architectures should be included in mutator, instead of the layer itself.

    Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
    decisions among different mutables. In mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is not None:
            if not isinstance(key, str):
                key = str(key)
                print("Warning: key \"{}\" is not string, converted to string.".format(key))
            self._key = key
        else:
            self._key = self.__class__.__name__ + str(global_mutable_counting())
        self.name = self.key

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __enter__(self):
        self._check_built()
        return super().__enter__()

    def __call__(self, *args, **kwargs):
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        self.__dict__["mutator"] = mutator

    def forward(self, *inputs):
        raise NotImplementedError("Mutable forward must be implemented.")

    @property
    def key(self):
        return self._key

    def similar(self, other):
        return self == other

    def _check_built(self):
        if not hasattr(self, "mutator"):
            raise ValueError("Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__ "
                             "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return "{} ({})".format(self.name, self.key)


class MutableScope(PyTorchMutable):
    """
    Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
    is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
    corresponding events, and do status dump or update.
    """

    def __init__(self, key):
        super().__init__(key=key)

    def __enter__(self):
        self.mutator.enter_mutable_scope(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mutator.exit_mutable_scope(self)


class LayerChoice(PyTorchMutable):
    def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
        super().__init__(key=key)
        self.length = len(op_candidates)
        self.choices = nn.ModuleList(op_candidates)
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, *inputs):
        out, mask = self.mutator.on_forward(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def similar(self, other):
        return type(self) == type(other) and self.length == other.length


class InputChoice(PyTorchMutable):
    def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates > 0, "Number of candidates must be greater than 0."
        self.n_candidates = n_candidates
        self.n_selected = n_selected
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, optional_inputs, semantic_labels=None):
        assert len(optional_inputs) == self.n_candidates, \
            "Length of the input list must be equal to number of candidates."
        if semantic_labels is None:
            semantic_labels = ["default_label"] * self.n_candidates
        out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
        if self.return_mask:
            return out, mask
        return out

    def similar(self, other):
        return type(self) == type(other) and \
            self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
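
Editor's note: as a rough illustration of how these mutables are meant to be used (the example is not part of the commit; the candidate ops and channel counts are made up), a searchable cell declares LayerChoice and InputChoice in __init__ and calls them in forward. A mutator such as EnasMutator must be attached before the forward pass actually runs.

import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

class SearchCell(nn.Module):
    def __init__(self):
        super().__init__()
        # the mutator decides which candidate op is applied
        self.op = LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)])
        # the mutator decides which of the two previous outputs is mixed back in
        self.skip = InputChoice(n_candidates=2)

    def forward(self, prev, prevprev):
        out = self.op(prev)
        return out + self.skip([prev, prevprev], semantic_labels=["prev", "prevprev"])
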
src/sdk/pynni/nni/nas/pytorch/mutator.py (new file, mode 0 → 100644)

import logging
from contextlib import contextmanager

import torch
import torch.nn as nn

from nni.nas.pytorch.mutables import PyTorchMutable
from nni.nas.utils import to_snake_case

logger = logging.getLogger(__name__)


class PyTorchMutator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.before_build(model)
        self.parse_search_space(model)
        self.after_build(model)

    def before_build(self, model):
        pass

    def after_build(self, model):
        pass

    def named_mutables(self, model):
        # if distinct is true, the method will filter out those with duplicated keys
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, PyTorchMutable):
                distinct = False
                if module.key in key2module:
                    assert key2module[module.key].similar(module), \
                        "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
                else:
                    distinct = True
                    key2module[module.key] = module
                yield name, module, distinct

    def __setattr__(self, key, value):
        if key in ["model", "net", "network"]:
            logger.warning("Think twice if you are including the network into mutator.")
        return super().__setattr__(key, value)

    def parse_search_space(self, model):
        for name, mutable, distinct in self.named_mutables(model):
            mutable.name = name
            mutable.set_mutator(self)
            if not distinct:
                continue
            init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__))
            if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)):
                getattr(self, init_method_name)(mutable)
            else:
                # fallback to general init
                self.on_init_general(mutable)

    def on_init_general(self, mutable):
        pass

    @contextmanager
    def forward_pass(self):
        self.before_pass()
        try:
            yield self
        finally:
            self.after_pass()

    def before_pass(self):
        self._in_forward_pass = True
        self._cache = dict()

    def after_pass(self):
        self._in_forward_pass = False

    def enter_mutable_scope(self, mutable_scope):
        pass

    def exit_mutable_scope(self, mutable_scope):
        pass

    def forward(self, *inputs):
        raise NotImplementedError("Mutator is not forward-able")

    def on_forward(self, mutable, *inputs):
        """Callback on forwarding a mutable"""
        if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
            raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
                             "super().before_pass() and after_pass() in your override method?")
        forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
        if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
            return getattr(self, forward_method_name)(mutable, *inputs)
        else:
            # fallback to general forward
            return self.on_forward_general(mutable, *inputs)

    def on_forward_general(self, mutable, *inputs):
        raise NotImplementedError("Forward has to be implemented")

    def on_forward_layer_choice(self, mutable, *inputs):
        """
        Callback of layer choice forward. Override if you are an advanced user.
        On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
        (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
        in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.

        Parameters
        ----------
        mutable: LayerChoice
        inputs: list of torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        def _map_fn(op, *inputs):
            return op(*inputs)

        mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
        out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
        """
        Callback of input choice forward. Override if you are an advanced user.
        On default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
        to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
        the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
        mask with corresponding `mutable.key`.

        Parameters
        ----------
        mutable: InputChoice
        inputs: list of torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_calc_layer_choice_mask(self, mutable):
        """
        Recommended to override. Calculate a mask tensor for a layer choice.

        Parameters
        ----------
        mutable: LayerChoice
            Corresponding layer choice object.

        Returns
        -------
        torch.Tensor
            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
            the numbers are treated as switch.
        """
        raise NotImplementedError("Layer choice mask calculation must be implemented")

    def on_calc_input_choice_mask(self, mutable, semantic_labels):
        """
        Recommended to override. Calculate a mask tensor for a input choice.

        Parameters
        ----------
        mutable: InputChoice
            Corresponding input choice object.
        semantic_labels: list of string
            The name of labels of input tensors given by user. Usually it's a
            :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.

        Returns
        -------
        torch.Tensor
            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
            the numbers are treated as switch.
        """
        raise NotImplementedError("Input choice mask calculation must be implemented")

    def _select_with_mask(self, map_fn, candidates, mask):
        if "BoolTensor" in mask.type():
            # print(candidates[0], len(mask))
            out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
        elif "FloatTensor" in mask.type():
            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
        else:
            raise ValueError("Unrecognized mask")
        return out

    def _tensor_reduction(self, reduction_type, tensor_list):
        if tensor_list == "none":
            return tensor_list
        if not tensor_list:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == "sum":
            return sum(tensor_list)
        if reduction_type == "mean":
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
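
Editor's note: the dispatch scheme above (on_forward_* callbacks falling back to on_calc_*_mask) is easiest to see with a concrete subclass. The sketch below is hypothetical and not part of the commit: a mutator that picks one candidate of every LayerChoice uniformly at random, returning the 1-D bool mask that _select_with_mask expects.

import torch
from nni.nas.pytorch.mutator import PyTorchMutator

class RandomLayerMutator(PyTorchMutator):
    def on_calc_layer_choice_mask(self, mutable):
        # exactly one True entry among mutable.length candidates
        mask = torch.zeros(mutable.length, dtype=torch.bool)
        mask[torch.randint(mutable.length, (1,)).item()] = True
        return mask

# Usage: sampling happens lazily, per LayerChoice key, inside the forward pass.
# mutator = RandomLayerMutator(model)
# with mutator.forward_pass():
#     out = model(x)
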
src/sdk/pynni/nni/nas/pytorch/trainer.py (new file, mode 0 → 100644)

from abc import ABC, abstractmethod


class Trainer(ABC):
    @abstractmethod
    def train(self):
        raise NotImplementedError

    @abstractmethod
    def export(self):
        raise NotImplementedError
src/sdk/pynni/nni/nas/tensorflow/__init__.py (new file, mode 0 → 100644, empty)

src/sdk/pynni/nni/nas/utils.py (new file, mode 0 → 100644)

import re
from collections import OrderedDict

import torch

_counter = 0


def global_mutable_counting():
    global _counter
    _counter += 1
    return _counter


def to_snake_case(camel_case):
    return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower()


def auto_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class AverageMeterGroup(object):
    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __str__(self):
        return " ".join(str(v) for _, v in self.meters.items())


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
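
Editor's note: a quick illustration (not in the commit) of how these helpers behave; the printed counter value assumes a fresh process, since global_mutable_counting is module-global state.

from nni.nas.utils import to_snake_case, global_mutable_counting, AverageMeterGroup

print(to_snake_case("LayerChoice"))   # -> "layer_choice", which drives the on_forward_* dispatch
print(global_mutable_counting())      # -> 1 on the first call, then 2, 3, ... (default mutable keys)

meters = AverageMeterGroup()
meters.update({"loss": 0.93, "acc": 0.51})
meters.update({"loss": 0.72, "acc": 0.64})
print(meters)                         # current value and running average for each key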