Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a7b96de9
Unverified
Commit
a7b96de9
authored
Apr 03, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 03, 2020
Browse files
Enable visualization in examples (#2261)
parent
63bd0f50
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
0 deletions
+34
-0
examples/nas/darts/search.py
examples/nas/darts/search.py
+3
-0
examples/nas/enas/search.py
examples/nas/enas/search.py
+3
-0
examples/nas/naive/train.py
examples/nas/naive/train.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
+2
-0
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+24
-0
No files found.
examples/nas/darts/search.py
View file @
a7b96de9
...
...
@@ -24,6 +24,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--epochs"
,
default
=
50
,
type
=
int
)
parser
.
add_argument
(
"--channels"
,
default
=
16
,
type
=
int
)
parser
.
add_argument
(
"--unrolled"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
...
...
@@ -45,4 +46,6 @@ if __name__ == "__main__":
log_frequency
=
args
.
log_frequency
,
unrolled
=
args
.
unrolled
,
callbacks
=
[
LRSchedulerCallback
(
lr_scheduler
),
ArchitectureCheckpoint
(
"./checkpoints"
)])
if
args
.
visualization
:
trainer
.
enable_visualization
()
trainer
.
train
()
examples/nas/enas/search.py
View file @
a7b96de9
...
...
@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
parser
.
add_argument
(
"--search-for"
,
choices
=
[
"macro"
,
"micro"
],
default
=
"macro"
)
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
...
...
@@ -55,4 +56,6 @@ if __name__ == "__main__":
dataset_valid
=
dataset_valid
,
log_frequency
=
args
.
log_frequency
,
mutator
=
mutator
)
if
args
.
visualization
:
trainer
.
enable_visualization
()
trainer
.
train
()
examples/nas/naive/train.py
View file @
a7b96de9
...
...
@@ -68,5 +68,6 @@ if __name__ == "__main__":
dataset_valid
=
dataset_valid
,
batch_size
=
64
,
log_frequency
=
10
)
trainer
.
enable_visualization
()
trainer
.
train
()
trainer
.
export
(
"checkpoint.json"
)
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
View file @
a7b96de9
...
...
@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self
.
mutator
.
reset
()
logits
=
self
.
model
(
X
)
loss
=
self
.
loss
(
logits
,
y
)
self
.
_write_graph_status
()
return
logits
,
loss
def
_backward
(
self
,
val_X
,
val_y
):
...
...
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
View file @
a7b96de9
...
...
@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with
torch
.
no_grad
():
self
.
mutator
.
reset
()
self
.
_write_graph_status
()
logits
=
self
.
model
(
x
)
if
isinstance
(
logits
,
tuple
):
...
...
@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self
.
mutator
.
reset
()
with
torch
.
no_grad
():
logits
=
self
.
model
(
x
)
self
.
_write_graph_status
()
metrics
=
self
.
metrics
(
logits
,
y
)
reward
=
self
.
reward_function
(
logits
,
y
)
if
self
.
entropy_weight
:
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
a7b96de9
...
...
@@ -3,6 +3,8 @@
import
json
import
logging
import
os
import
time
from
abc
import
abstractmethod
import
torch
...
...
@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self
.
batch_size
=
batch_size
self
.
workers
=
workers
self
.
log_frequency
=
log_frequency
self
.
log_dir
=
os
.
path
.
join
(
"logs"
,
str
(
time
.
time
()))
os
.
makedirs
(
self
.
log_dir
,
exist_ok
=
True
)
self
.
status_writer
=
open
(
os
.
path
.
join
(
self
.
log_dir
,
"log"
),
"w"
)
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
for
callback
in
self
.
callbacks
:
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
...
...
@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
"""
raise
NotImplementedError
(
"Not implemented yet"
)
def
enable_visualization
(
self
):
"""
Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
"""
sample
=
None
for
x
,
_
in
self
.
train_loader
:
sample
=
x
.
to
(
self
.
device
)[:
2
]
break
if
sample
is
None
:
_logger
.
warning
(
"Sample is %s."
,
sample
)
_logger
.
info
(
"Creating graph json, writing to %s. Visualization enabled."
,
self
.
log_dir
)
with
open
(
os
.
path
.
join
(
self
.
log_dir
,
"graph.json"
),
"w"
)
as
f
:
json
.
dump
(
self
.
mutator
.
graph
(
sample
),
f
)
self
.
visualization_enabled
=
True
def
_write_graph_status
(
self
):
if
hasattr
(
self
,
"visualization_enabled"
)
and
self
.
visualization_enabled
:
print
(
json
.
dumps
(
self
.
mutator
.
status
()),
file
=
self
.
status_writer
,
flush
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment