Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
77e91e8b
"docs/en_US/vscode:/vscode.git/clone" did not exist on "0717988f2f45baeb8483536e877defd92c82cfae"
Commit
77e91e8b
authored
Nov 21, 2019
by
Yuge Zhang
Committed by
Chi Song
Nov 21, 2019
Browse files
Extract controller from mutator to make offline decisions (#1758)
parent
9dda5370
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
137 additions
and
14 deletions
+137
-14
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
+1
-2
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+29
-12
src/sdk/pynni/nni/nas/pytorch/utils.py
src/sdk/pynni/nni/nas/pytorch/utils.py
+107
-0
No files found.
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
View file @
77e91e8b
...
@@ -35,8 +35,7 @@ class PdartsTrainer(Trainer):
...
@@ -35,8 +35,7 @@ class PdartsTrainer(Trainer):
layers
=
self
.
layers
+
self
.
pdarts_num_layers
[
epoch
]
layers
=
self
.
layers
+
self
.
pdarts_num_layers
[
epoch
]
model
,
loss
,
model_optim
,
_
=
self
.
model_creator
(
model
,
loss
,
model_optim
,
_
=
self
.
model_creator
(
layers
,
n_nodes
)
layers
,
n_nodes
)
mutator
=
PdartsMutator
(
mutator
=
PdartsMutator
(
model
,
epoch
,
self
.
pdarts_num_to_drop
,
switches
)
# pylint: disable=too-many-function-args
model
,
epoch
,
self
.
pdarts_num_to_drop
,
switches
)
self
.
trainer
=
DartsTrainer
(
model
,
loss
=
loss
,
optimizer
=
model_optim
,
self
.
trainer
=
DartsTrainer
(
model
,
loss
=
loss
,
optimizer
=
model_optim
,
mutator
=
mutator
,
**
self
.
darts_parameters
)
mutator
=
mutator
,
**
self
.
darts_parameters
)
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
77e91e8b
import
json
import
logging
from
abc
import
abstractmethod
from
abc
import
abstractmethod
import
torch
import
torch
from
.base_trainer
import
BaseTrainer
from
.base_trainer
import
BaseTrainer
_logger
=
logging
.
getLogger
(
__name__
)
class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``torch.Tensor`` objects as (nested) lists.

    Non-tensor objects are delegated to ``json.JSONEncoder.default``, which
    raises ``TypeError`` for unsupported types.
    """

    def default(self, o):  # pylint: disable=method-hidden
        if not isinstance(o, torch.Tensor):
            # Not a tensor: let the base encoder handle (or reject) it.
            return super().default(o)
        olist = o.tolist()
        # Warn when a non-bool tensor contains only 0/1 values: it probably
        # should have been a bool tensor in the first place.
        if "bool" not in o.type().lower() and all(d == 0 or d == 1 for d in olist):
            _logger.warning("Every element in %s is either 0 or 1. "
                            "You might consider convert it into bool.", olist)
        return olist
class
Trainer
(
BaseTrainer
):
class
Trainer
(
BaseTrainer
):
def
__init__
(
self
,
model
,
loss
,
metrics
,
optimizer
,
num_epochs
,
def
__init__
(
self
,
model
,
mutator
,
loss
,
metrics
,
optimizer
,
num_epochs
,
dataset_train
,
dataset_valid
,
batch_size
,
workers
,
device
,
log_frequency
,
dataset_train
,
dataset_valid
,
batch_size
,
workers
,
device
,
log_frequency
,
callbacks
):
mutator
,
callbacks
):
self
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
if
device
is
None
else
device
self
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
if
device
is
None
else
device
self
.
model
=
model
self
.
model
=
model
self
.
mutator
=
mutator
self
.
loss
=
loss
self
.
loss
=
loss
self
.
metrics
=
metrics
self
.
metrics
=
metrics
self
.
optimizer
=
optimizer
self
.
optimizer
=
optimizer
self
.
mutator
=
mutator
self
.
model
.
to
(
self
.
device
)
self
.
model
.
to
(
self
.
device
)
self
.
loss
.
to
(
self
.
device
)
self
.
mutator
.
to
(
self
.
device
)
self
.
mutator
.
to
(
self
.
device
)
self
.
loss
.
to
(
self
.
device
)
self
.
num_epochs
=
num_epochs
self
.
num_epochs
=
num_epochs
self
.
dataset_train
=
dataset_train
self
.
dataset_train
=
dataset_train
...
@@ -38,7 +53,7 @@ class Trainer(BaseTrainer):
...
@@ -38,7 +53,7 @@ class Trainer(BaseTrainer):
def
validate_one_epoch
(
self
,
epoch
):
def
validate_one_epoch
(
self
,
epoch
):
pass
pass
def
_
train
(
self
,
validate
):
def
train
(
self
,
validate
=
True
):
for
epoch
in
range
(
self
.
num_epochs
):
for
epoch
in
range
(
self
.
num_epochs
):
for
callback
in
self
.
callbacks
:
for
callback
in
self
.
callbacks
:
callback
.
on_epoch_begin
(
epoch
)
callback
.
on_epoch_begin
(
epoch
)
...
@@ -55,11 +70,13 @@ class Trainer(BaseTrainer):
...
@@ -55,11 +70,13 @@ class Trainer(BaseTrainer):
for
callback
in
self
.
callbacks
:
for
callback
in
self
.
callbacks
:
callback
.
on_epoch_end
(
epoch
)
callback
.
on_epoch_end
(
epoch
)
def
train_and_validate
(
self
):
self
.
_train
(
True
)
def
train
(
self
):
self
.
_train
(
False
)
def
validate
(
self
):
def
validate
(
self
):
self
.
validate_one_epoch
(
-
1
)
self
.
validate_one_epoch
(
-
1
)
def
export
(
self
,
file
):
mutator_export
=
self
.
mutator
.
export
()
with
open
(
file
,
"w"
)
as
f
:
json
.
dump
(
mutator_export
,
f
,
indent
=
2
,
sort_keys
=
True
,
cls
=
TorchTensorEncoder
)
def
checkpoint
(
self
):
raise
NotImplementedError
(
"Not implemented yet"
)
src/sdk/pynni/nni/nas/pytorch/utils.py
0 → 100644
View file @
77e91e8b
from
collections
import
OrderedDict
# Process-wide counter backing global_mutable_counting().
_counter = 0


def global_mutable_counting():
    """Return the next value (starting from 1) of a global monotonically increasing counter."""
    global _counter
    _counter = _counter + 1
    return _counter
class AverageMeterGroup:
    """A collection of named :class:`AverageMeter` objects, updated together from a dict."""

    def __init__(self):
        # Insertion-ordered mapping: metric name -> AverageMeter.
        self.meters = OrderedDict()

    def update(self, data):
        """Feed one value per metric; meters are created lazily on first sight of a name."""
        for name, value in data.items():
            if name not in self.meters:
                self.meters[name] = AverageMeter(name, ":4f")
            self.meters[name].update(value)

    def __str__(self):
        return " ".join(str(meter) for meter in self.meters.values())
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format spec (e.g. ':4f') used when rendering val/avg.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero out all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:f} ({avg:f})' and fill it from the instance dict.
        template = '{{name}} {{val{fmt}}} ({{avg{fmt}}})'.format(fmt=self.fmt)
        return template.format(**self.__dict__)
class StructuredMutableTreeNode:
    """
    A structured representation of a search space.
    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """Wrap ``mutable`` in a new node, attach it as a child, and return the node."""
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def type(self):
        """Return the concrete type of the wrapped mutable."""
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that generates a list of mutables in this tree.

        Parameters
        ----------
        order: str
            pre or post. If pre, current mutable is yield before children. Otherwise after.
        deduplicate: bool
            If true, mutables with the same key will not appear after the first appearance.
        memo: dict
            An auxiliary variable to make deduplicate happen.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            yield from self._emit(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from self._emit(deduplicate, memo)

    def _emit(self, deduplicate, memo):
        # Yield this node's mutable unless it is the root (None) or its key
        # was already seen and deduplication is on.
        if self.mutable is not None:
            if not deduplicate or self.mutable.key not in memo:
                memo.add(self.mutable.key)
                yield self.mutable
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment