Unverified Commit a7b96de9 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enable visualization in examples (#2261)

parent 63bd0f50
......@@ -24,6 +24,7 @@ if __name__ == "__main__":
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--channels", default=16, type=int)
parser.add_argument("--unrolled", default=False, action="store_true")
parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
......@@ -45,4 +46,6 @@ if __name__ == "__main__":
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization:
trainer.enable_visualization()
trainer.train()
......@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
......@@ -55,4 +56,6 @@ if __name__ == "__main__":
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
if args.visualization:
trainer.enable_visualization()
trainer.train()
......@@ -68,5 +68,6 @@ if __name__ == "__main__":
dataset_valid=dataset_valid,
batch_size=64,
log_frequency=10)
trainer.enable_visualization()
trainer.train()
trainer.export("checkpoint.json")
......@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
self._write_graph_status()
return logits, loss
def _backward(self, val_X, val_y):
......
......@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with torch.no_grad():
self.mutator.reset()
self._write_graph_status()
logits = self.model(x)
if isinstance(logits, tuple):
......@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
self._write_graph_status()
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
......
......@@ -3,6 +3,8 @@
import json
import logging
import os
import time
from abc import abstractmethod
import torch
......@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self.batch_size = batch_size
self.workers = workers
self.log_frequency = log_frequency
self.log_dir = os.path.join("logs", str(time.time()))
os.makedirs(self.log_dir, exist_ok=True)
self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks:
callback.build(self.model, self.mutator, self)
......@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
"""
raise NotImplementedError("Not implemented yet")
def enable_visualization(self):
    """
    Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.

    Pulls one probe batch (first 2 samples) from ``self.train_loader``, traces
    the mutator's graph with it, dumps the graph as ``graph.json`` into
    ``self.log_dir``, and sets ``visualization_enabled`` so that subsequent
    ``_write_graph_status`` calls stream mutator status to the status writer.
    """
    # Grab a tiny probe batch (2 samples) to trace the graph with.
    sample = None
    for x, _ in self.train_loader:
        sample = x.to(self.device)[:2]
        break
    if sample is None:
        # Empty training loader: there is nothing to trace. Bail out instead of
        # passing None into mutator.graph(), which would fail downstream.
        _logger.warning("Training loader yielded no batch; graph visualization is disabled.")
        return
    _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
    with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
        json.dump(self.mutator.graph(sample), f)
    self.visualization_enabled = True
def _write_graph_status(self):
if hasattr(self, "visualization_enabled") and self.visualization_enabled:
print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment