Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a7b96de9
"vscode:/vscode.git/clone" did not exist on "2fb93bdb3c8f6bbe0bb653c35cbef991591646a2"
Unverified
Commit
a7b96de9
authored
Apr 03, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 03, 2020
Browse files
Enable visualization in examples (#2261)
parent
63bd0f50
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
0 deletions
+34
-0
examples/nas/darts/search.py
examples/nas/darts/search.py
+3
-0
examples/nas/enas/search.py
examples/nas/enas/search.py
+3
-0
examples/nas/naive/train.py
examples/nas/naive/train.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
+2
-0
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+24
-0
No files found.
examples/nas/darts/search.py
View file @
a7b96de9
...
@@ -24,6 +24,7 @@ if __name__ == "__main__":
...
@@ -24,6 +24,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--epochs"
,
default
=
50
,
type
=
int
)
parser
.
add_argument
(
"--epochs"
,
default
=
50
,
type
=
int
)
parser
.
add_argument
(
"--channels"
,
default
=
16
,
type
=
int
)
parser
.
add_argument
(
"--channels"
,
default
=
16
,
type
=
int
)
parser
.
add_argument
(
"--unrolled"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--unrolled"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
...
@@ -45,4 +46,6 @@ if __name__ == "__main__":
...
@@ -45,4 +46,6 @@ if __name__ == "__main__":
log_frequency
=
args
.
log_frequency
,
log_frequency
=
args
.
log_frequency
,
unrolled
=
args
.
unrolled
,
unrolled
=
args
.
unrolled
,
callbacks
=
[
LRSchedulerCallback
(
lr_scheduler
),
ArchitectureCheckpoint
(
"./checkpoints"
)])
callbacks
=
[
LRSchedulerCallback
(
lr_scheduler
),
ArchitectureCheckpoint
(
"./checkpoints"
)])
if
args
.
visualization
:
trainer
.
enable_visualization
()
trainer
.
train
()
trainer
.
train
()
examples/nas/enas/search.py
View file @
a7b96de9
...
@@ -25,6 +25,7 @@ if __name__ == "__main__":
...
@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
parser
.
add_argument
(
"--search-for"
,
choices
=
[
"macro"
,
"micro"
],
default
=
"macro"
)
parser
.
add_argument
(
"--search-for"
,
choices
=
[
"macro"
,
"micro"
],
default
=
"macro"
)
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
dataset_train
,
dataset_valid
=
datasets
.
get_dataset
(
"cifar10"
)
...
@@ -55,4 +56,6 @@ if __name__ == "__main__":
...
@@ -55,4 +56,6 @@ if __name__ == "__main__":
dataset_valid
=
dataset_valid
,
dataset_valid
=
dataset_valid
,
log_frequency
=
args
.
log_frequency
,
log_frequency
=
args
.
log_frequency
,
mutator
=
mutator
)
mutator
=
mutator
)
if
args
.
visualization
:
trainer
.
enable_visualization
()
trainer
.
train
()
trainer
.
train
()
examples/nas/naive/train.py
View file @
a7b96de9
...
@@ -68,5 +68,6 @@ if __name__ == "__main__":
...
@@ -68,5 +68,6 @@ if __name__ == "__main__":
dataset_valid
=
dataset_valid
,
dataset_valid
=
dataset_valid
,
batch_size
=
64
,
batch_size
=
64
,
log_frequency
=
10
)
log_frequency
=
10
)
trainer
.
enable_visualization
()
trainer
.
train
()
trainer
.
train
()
trainer
.
export
(
"checkpoint.json"
)
trainer
.
export
(
"checkpoint.json"
)
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
View file @
a7b96de9
...
@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
...
@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self
.
mutator
.
reset
()
self
.
mutator
.
reset
()
logits
=
self
.
model
(
X
)
logits
=
self
.
model
(
X
)
loss
=
self
.
loss
(
logits
,
y
)
loss
=
self
.
loss
(
logits
,
y
)
self
.
_write_graph_status
()
return
logits
,
loss
return
logits
,
loss
def
_backward
(
self
,
val_X
,
val_y
):
def
_backward
(
self
,
val_X
,
val_y
):
...
...
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
View file @
a7b96de9
...
@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
...
@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
self
.
mutator
.
reset
()
self
.
mutator
.
reset
()
self
.
_write_graph_status
()
logits
=
self
.
model
(
x
)
logits
=
self
.
model
(
x
)
if
isinstance
(
logits
,
tuple
):
if
isinstance
(
logits
,
tuple
):
...
@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
...
@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self
.
mutator
.
reset
()
self
.
mutator
.
reset
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
logits
=
self
.
model
(
x
)
logits
=
self
.
model
(
x
)
self
.
_write_graph_status
()
metrics
=
self
.
metrics
(
logits
,
y
)
metrics
=
self
.
metrics
(
logits
,
y
)
reward
=
self
.
reward_function
(
logits
,
y
)
reward
=
self
.
reward_function
(
logits
,
y
)
if
self
.
entropy_weight
:
if
self
.
entropy_weight
:
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
a7b96de9
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
json
import
json
import
logging
import
logging
import
os
import
time
from
abc
import
abstractmethod
from
abc
import
abstractmethod
import
torch
import
torch
...
@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
...
@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
workers
=
workers
self
.
workers
=
workers
self
.
log_frequency
=
log_frequency
self
.
log_frequency
=
log_frequency
self
.
log_dir
=
os
.
path
.
join
(
"logs"
,
str
(
time
.
time
()))
os
.
makedirs
(
self
.
log_dir
,
exist_ok
=
True
)
self
.
status_writer
=
open
(
os
.
path
.
join
(
self
.
log_dir
,
"log"
),
"w"
)
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
for
callback
in
self
.
callbacks
:
for
callback
in
self
.
callbacks
:
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
...
@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
...
@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
Return trainer checkpoint.
"""
"""
raise
NotImplementedError
(
"Not implemented yet"
)
raise
NotImplementedError
(
"Not implemented yet"
)
def
enable_visualization
(
self
):
"""
Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
"""
sample
=
None
for
x
,
_
in
self
.
train_loader
:
sample
=
x
.
to
(
self
.
device
)[:
2
]
break
if
sample
is
None
:
_logger
.
warning
(
"Sample is %s."
,
sample
)
_logger
.
info
(
"Creating graph json, writing to %s. Visualization enabled."
,
self
.
log_dir
)
with
open
(
os
.
path
.
join
(
self
.
log_dir
,
"graph.json"
),
"w"
)
as
f
:
json
.
dump
(
self
.
mutator
.
graph
(
sample
),
f
)
self
.
visualization_enabled
=
True
def
_write_graph_status
(
self
):
if
hasattr
(
self
,
"visualization_enabled"
)
and
self
.
visualization_enabled
:
print
(
json
.
dumps
(
self
.
mutator
.
status
()),
file
=
self
.
status_writer
,
flush
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment