Unverified commit 1248da3c, authored by Wencheng Wu, committed by GitHub
Browse files

[Fix] Fix the upload graph error of PaviLoggerHook. (#2100)

* [Fix] Fix add_graph function of PaviLoggerHook.

* Fix circular reference error.

* Add comments.

* Modify partial_args default to an empty dict.

* Add warning for add_graph and img_key parameters.
parent 78f01001
......@@ -2,14 +2,17 @@
import json
import os
import os.path as osp
import warnings
from functools import partial
from typing import Dict, Optional
import torch
import yaml
import mmcv
from ....parallel.utils import is_module_wrapper
from ...dist_utils import master_only
from mmcv.parallel.scatter_gather import scatter
from mmcv.parallel.utils import is_module_wrapper
from mmcv.runner.dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -35,8 +38,9 @@ class PaviLoggerHook(LoggerHook):
- overwrite_last_training (bool, optional): Whether to upload data
to the training with the same name in the same project, rather
than creating a new one. Defaults to False.
add_graph (bool): **Deprecated**. Whether to visualize the model.
add_graph (bool, optional): **Deprecated**. Whether to visualize the model.
Default: False.
img_key (str, optional): **Deprecated**. Image key. Defaults to None.
add_last_ckpt (bool): Whether to save checkpoint after run.
Default: False.
interval (int): Logging interval (every k iterations). Default: 10.
......@@ -45,39 +49,64 @@ class PaviLoggerHook(LoggerHook):
reset_flag (bool): Whether to clear the output buffer after logging.
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
img_key (string): Get image data from Dataset. Default: 'img_info'.
add_graph_kwargs (dict, optional): A dict contains the params for
adding graph, the keys are as below:
Default: {'active': False, 'start': 0, 'interval': 1}.
- active (bool): Whether to use ``add_graph``. Default: False.
- start (int): The epoch or iteration to start. Default: 0.
- interval (int): Interval of ``add_graph``. Default: 1.
- img_key (str): Get image data from Dataset. Default: 'img'.
- opset_version (int): ``opset_version`` of exporting onnx.
Default: 11.
- dummy_forward_kwargs (dict, optional): Set default parameters to
model forward function except image. For example, you can set
{'return_loss': False} for mmcls. Default: None.
add_ckpt_kwargs (dict, optional): A dict contains the params for
adding checkpoint, the keys are as below:
Default: {'active': False, 'start': 0, 'interval': 1}.
- active (bool): Whether to upload checkpoint. Default: False.
- start (int): The epoch or iteration to start. Default: 0.
- interval (int): Interval of upload checkpoint. Default: 1.
"""
def __init__(self,
init_kwargs: Optional[Dict] = None,
add_graph: bool = False,
add_graph: Optional[bool] = None,
img_key: Optional[str] = None,
add_last_ckpt: bool = False,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
img_key: str = 'img_info',
add_graph_kwargs: Optional[Dict] = None,
add_ckpt_kwargs: Optional[Dict] = None) -> None:
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs
add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs
self.add_graph = add_graph_kwargs.get('active', False)
self.add_graph_start = add_graph_kwargs.get('start', 0)
self.add_graph_interval = add_graph_kwargs.get('interval', 1)
self.img_key = add_graph_kwargs.get('img_key', 'img')
self.opset_version = add_graph_kwargs.get('opset_version', 11)
self.dummy_forward_kwargs = add_graph_kwargs.get(
'dummy_forward_kwargs', {})
if add_graph is not None:
warnings.warn(
'"add_graph" is deprecated in `PaviLoggerHook`, please use '
'the key "active" of add_graph_kwargs instead',
DeprecationWarning)
self.add_graph = add_graph
if img_key is not None:
warnings.warn(
'"img_key" is deprecated in `PaviLoggerHook`, please use '
'the key "img_key" of add_graph_kwargs instead',
DeprecationWarning)
self.img_key = img_key
add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs
self.add_ckpt = add_ckpt_kwargs.get('active', False)
self.add_last_ckpt = add_last_ckpt
self.add_ckpt_start = add_ckpt_kwargs.get('start', 0)
self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1)
self.img_key = img_key
@master_only
def before_run(self, runner) -> None:
......@@ -138,16 +167,34 @@ class PaviLoggerHook(LoggerHook):
snapshot_file_path=ckpt_path,
iteration=step)
def _add_graph(self, runner) -> None:
def _add_graph(self, runner, step: int) -> None:
from mmcv.runner.iter_based_runner import IterLoader
if is_module_wrapper(runner.model):
_model = runner.model.module
else:
_model = runner.model
device = next(_model.parameters()).device
data = next(iter(runner.data_loader))
image = data[self.img_key][0:1].to(device)
# Note that if your sampler indices is generated in init method, your
# dataset may be one less.
if isinstance(runner.data_loader, IterLoader):
data = next(iter(runner.data_loader._dataloader))
else:
data = next(iter(runner.data_loader))
data = scatter(data, [device.index])[0]
img = data[self.img_key]
with torch.no_grad():
self.writer.add_graph(_model, image)
origin_forward = _model.forward
if hasattr(_model, 'forward_dummy'):
_model.forward = _model.forward_dummy
if self.dummy_forward_kwargs:
_model.forward = partial(_model.forward,
**self.dummy_forward_kwargs)
self.writer.add_graph(
_model,
img,
tag=f'{self.run_name}_{step}',
opset_version=self.opset_version)
_model.forward = origin_forward
@master_only
def log(self, runner) -> None:
......@@ -180,7 +227,7 @@ class PaviLoggerHook(LoggerHook):
if (self.add_graph and step >= self.add_graph_start
and ((step - self.add_graph_start) % self.add_graph_interval
== 0)): # noqa: E129
self._add_graph(runner)
self._add_graph(runner, step)
@master_only
def before_train_iter(self, runner) -> None:
......@@ -193,7 +240,7 @@ class PaviLoggerHook(LoggerHook):
if (self.add_graph and step >= self.add_graph_start
and ((step - self.add_graph_start) % self.add_graph_interval
== 0)): # noqa: E129
self._add_graph(runner)
self._add_graph(runner, step)
@master_only
def after_train_epoch(self, runner) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment