Unverified Commit f46e5f8e authored by gengenkai's avatar gengenkai Committed by GitHub
Browse files

[Fix] Fix add_graph in pavi (#948)

* [Fix] Fix add_graph in pavi

* change data loader to image

* Delete =2.4.0

* pavi-add_graph-0419

* pavi-add_graph-0419

* [Fix] pavi device

* fix device in pavi-add graph

* img_key

* img_key

* add no_grad

* Delete version.py

* add version.py
parent c142eced
......@@ -6,6 +6,7 @@ import os.path as osp
import yaml
import mmcv
from ....parallel.utils import is_module_wrapper
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -21,12 +22,14 @@ class PaviLoggerHook(LoggerHook):
interval=10,
ignore_last=True,
reset_flag=True,
by_epoch=True):
by_epoch=True,
img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt
self.img_key=img_key
@master_only
def before_run(self, runner):
......@@ -66,9 +69,6 @@ class PaviLoggerHook(LoggerHook):
self.init_kwargs['session_text'] = session_text
self.writer = SummaryWriter(**self.init_kwargs)
if self.add_graph:
self.writer.add_graph(runner.model)
def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
......@@ -95,3 +95,16 @@ class PaviLoggerHook(LoggerHook):
tag=self.run_name,
snapshot_file_path=ckpt_path,
iteration=iteration)
@master_only
def before_epoch(self, runner):
if runner.epoch == 0 and self.add_graph:
if is_module_wrapper(runner.model):
_model = runner.model.module
else:
_model = runner.model
device = next(_model.parameters()).device
data = next(iter(runner.data_loader))
image = data[self.img_key][0:1].to(device)
with torch.no_grad():
self.writer.add_graph(_model, image)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment