Unverified Commit 1248da3c authored by Wencheng Wu's avatar Wencheng Wu Committed by GitHub
Browse files

[Fix] Fix the upload graph error of PaviLoggerHook. (#2100)

* [Fix] Fix add_graph function of PaviLoggerHook.

* Fix circular reference error.

* Add comments.

* modify partial_args default to empty dict.

* Add warnning for add_graph and img_key parameter.
parent 78f01001
...@@ -2,14 +2,17 @@ ...@@ -2,14 +2,17 @@
import json import json
import os import os
import os.path as osp import os.path as osp
import warnings
from functools import partial
from typing import Dict, Optional from typing import Dict, Optional
import torch import torch
import yaml import yaml
import mmcv import mmcv
from ....parallel.utils import is_module_wrapper from mmcv.parallel.scatter_gather import scatter
from ...dist_utils import master_only from mmcv.parallel.utils import is_module_wrapper
from mmcv.runner.dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -35,8 +38,9 @@ class PaviLoggerHook(LoggerHook): ...@@ -35,8 +38,9 @@ class PaviLoggerHook(LoggerHook):
- overwrite_last_training (bool, optional): Whether to upload data - overwrite_last_training (bool, optional): Whether to upload data
to the training with the same name in the same project, rather to the training with the same name in the same project, rather
than creating a new one. Defaults to False. than creating a new one. Defaults to False.
add_graph (bool): **Deprecated**. Whether to visual model. add_graph (bool, optional): **Deprecated**. Whether to visual model.
Default: False. Default: False.
img_key (str, optional): **Deprecated**. Image key. Defaults to None.
add_last_ckpt (bool): Whether to save checkpoint after run. add_last_ckpt (bool): Whether to save checkpoint after run.
Default: False. Default: False.
interval (int): Logging interval (every k iterations). Default: True. interval (int): Logging interval (every k iterations). Default: True.
...@@ -45,39 +49,64 @@ class PaviLoggerHook(LoggerHook): ...@@ -45,39 +49,64 @@ class PaviLoggerHook(LoggerHook):
reset_flag (bool): Whether to clear the output buffer after logging. reset_flag (bool): Whether to clear the output buffer after logging.
Default: False. Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True. by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
img_key (string): Get image data from Dataset. Default: 'img_info'.
add_graph_kwargs (dict, optional): A dict contains the params for add_graph_kwargs (dict, optional): A dict contains the params for
adding graph, the keys are as below: adding graph, the keys are as below:
Default: {'active': False, 'start': 0, 'interval': 1}. - active (bool): Whether to use ``add_graph``. Default: False.
- start (int): The epoch or iteration to start. Default: 0.
- interval (int): Interval of ``add_graph``. Default: 1.
- img_key (str): Get image data from Dataset. Default: 'img'.
- opset_version (int): ``opset_version`` of exporting onnx.
Default: 11.
- dummy_forward_kwargs (dict, optional): Set default parameters to
model forward function except image. For example, you can set
{'return_loss': False} for mmcls. Default: None.
add_ckpt_kwargs (dict, optional): A dict contains the params for add_ckpt_kwargs (dict, optional): A dict contains the params for
adding checkpoint, the keys are as below: adding checkpoint, the keys are as below:
Default: {'active': False, 'start': 0, 'interval': 1}. - active (bool): Whether to upload checkpoint. Default: False.
- start (int): The epoch or iteration to start. Default: 0.
- interval (int): Interval of upload checkpoint. Default: 1.
""" """
def __init__(self, def __init__(self,
init_kwargs: Optional[Dict] = None, init_kwargs: Optional[Dict] = None,
add_graph: bool = False, add_graph: Optional[bool] = None,
img_key: Optional[str] = None,
add_last_ckpt: bool = False, add_last_ckpt: bool = False,
interval: int = 10, interval: int = 10,
ignore_last: bool = True, ignore_last: bool = True,
reset_flag: bool = False, reset_flag: bool = False,
by_epoch: bool = True, by_epoch: bool = True,
img_key: str = 'img_info',
add_graph_kwargs: Optional[Dict] = None, add_graph_kwargs: Optional[Dict] = None,
add_ckpt_kwargs: Optional[Dict] = None) -> None: add_ckpt_kwargs: Optional[Dict] = None) -> None:
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs
self.add_graph = add_graph_kwargs.get('active', False) self.add_graph = add_graph_kwargs.get('active', False)
self.add_graph_start = add_graph_kwargs.get('start', 0) self.add_graph_start = add_graph_kwargs.get('start', 0)
self.add_graph_interval = add_graph_kwargs.get('interval', 1) self.add_graph_interval = add_graph_kwargs.get('interval', 1)
self.img_key = add_graph_kwargs.get('img_key', 'img')
self.opset_version = add_graph_kwargs.get('opset_version', 11)
self.dummy_forward_kwargs = add_graph_kwargs.get(
'dummy_forward_kwargs', {})
if add_graph is not None:
warnings.warn(
'"add_graph" is deprecated in `PaviLoggerHook`, please use '
'the key "active" of add_graph_kwargs instead',
DeprecationWarning)
self.add_graph = add_graph
if img_key is not None:
warnings.warn(
'"img_key" is deprecated in `PaviLoggerHook`, please use '
'the key "img_key" of add_graph_kwargs instead',
DeprecationWarning)
self.img_key = img_key
add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs
self.add_ckpt = add_ckpt_kwargs.get('active', False) self.add_ckpt = add_ckpt_kwargs.get('active', False)
self.add_last_ckpt = add_last_ckpt self.add_last_ckpt = add_last_ckpt
self.add_ckpt_start = add_ckpt_kwargs.get('start', 0) self.add_ckpt_start = add_ckpt_kwargs.get('start', 0)
self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1) self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1)
self.img_key = img_key
@master_only @master_only
def before_run(self, runner) -> None: def before_run(self, runner) -> None:
...@@ -138,16 +167,34 @@ class PaviLoggerHook(LoggerHook): ...@@ -138,16 +167,34 @@ class PaviLoggerHook(LoggerHook):
snapshot_file_path=ckpt_path, snapshot_file_path=ckpt_path,
iteration=step) iteration=step)
def _add_graph(self, runner) -> None: def _add_graph(self, runner, step: int) -> None:
from mmcv.runner.iter_based_runner import IterLoader
if is_module_wrapper(runner.model): if is_module_wrapper(runner.model):
_model = runner.model.module _model = runner.model.module
else: else:
_model = runner.model _model = runner.model
device = next(_model.parameters()).device device = next(_model.parameters()).device
data = next(iter(runner.data_loader)) # Note that if your sampler indices is generated in init method, your
image = data[self.img_key][0:1].to(device) # dataset may be one less.
if isinstance(runner.data_loader, IterLoader):
data = next(iter(runner.data_loader._dataloader))
else:
data = next(iter(runner.data_loader))
data = scatter(data, [device.index])[0]
img = data[self.img_key]
with torch.no_grad(): with torch.no_grad():
self.writer.add_graph(_model, image) origin_forward = _model.forward
if hasattr(_model, 'forward_dummy'):
_model.forward = _model.forward_dummy
if self.dummy_forward_kwargs:
_model.forward = partial(_model.forward,
**self.dummy_forward_kwargs)
self.writer.add_graph(
_model,
img,
tag=f'{self.run_name}_{step}',
opset_version=self.opset_version)
_model.forward = origin_forward
@master_only @master_only
def log(self, runner) -> None: def log(self, runner) -> None:
...@@ -180,7 +227,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -180,7 +227,7 @@ class PaviLoggerHook(LoggerHook):
if (self.add_graph and step >= self.add_graph_start if (self.add_graph and step >= self.add_graph_start
and ((step - self.add_graph_start) % self.add_graph_interval and ((step - self.add_graph_start) % self.add_graph_interval
== 0)): # noqa: E129 == 0)): # noqa: E129
self._add_graph(runner) self._add_graph(runner, step)
@master_only @master_only
def before_train_iter(self, runner) -> None: def before_train_iter(self, runner) -> None:
...@@ -193,7 +240,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -193,7 +240,7 @@ class PaviLoggerHook(LoggerHook):
if (self.add_graph and step >= self.add_graph_start if (self.add_graph and step >= self.add_graph_start
and ((step - self.add_graph_start) % self.add_graph_interval and ((step - self.add_graph_start) % self.add_graph_interval
== 0)): # noqa: E129 == 0)): # noqa: E129
self._add_graph(runner) self._add_graph(runner, step)
@master_only @master_only
def after_train_epoch(self, runner) -> None: def after_train_epoch(self, runner) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment