Commit fc5319b6 authored by pangjm

add build_hook & other minor modifications

parent ee54f9cf

@@ -3,7 +3,7 @@ from __future__ import division
 import cv2
 import numpy as np
-__all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'bbox_flip']
+__all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple']
 def imflip(img, direction='horizontal'):
@@ -111,20 +111,6 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
     return scaled_bboxes
-def bbox_flip(bboxes, img_shape):
-    """Flip bboxes horizontally
-    Args:
-        bboxes(ndarray): shape (..., 4*k)
-        img_shape(tuple): (height, width)
-    """
-    assert bboxes.shape[-1] % 4 == 0
-    w = img_shape[1]
-    flipped = bboxes.copy()
-    flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
-    flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
-    return flipped
 def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
     """Crop image patches.
...
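
The removed bbox_flip mirrored box x-coordinates about the vertical image axis (w - x - 1 keeps pixel coordinates in [0, w - 1]) and swapped the left/right edges so that x1 <= x2 still holds afterwards. A minimal standalone sketch of that index arithmetic; the sample values are illustrative, not from the diff:

import numpy as np

# One box per row in (x1, y1, x2, y2) layout; the 0::4 / 2::4 slices hit
# every x1 and x2 when a row packs 4*k values (k boxes).
bboxes = np.array([[10., 20., 30., 40.]])
w = 100  # image width
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1  # new x1 mirrors old x2
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1  # new x2 mirrors old x1
print(flipped)  # [[69. 20. 89. 40.]]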

 from .hook import Hook
-from .checkpoint_saver import CheckpointSaverHook
+from .checkpoint_saver import CheckpointHook
 from .closure import ClosureHook
 from .lr_updater import LrUpdaterHook
-from .optimizer_stepper import OptimizerStepperHook
+from .optimizer_stepper import OptimizerHook
 from .iter_timer import IterTimerHook
 from .logger import *
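
Both renames above (CheckpointSaverHook to CheckpointHook, OptimizerStepperHook to OptimizerHook) are propagated to the class definitions and to the Runner imports in the hunks that follow.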

@@ -2,7 +2,7 @@ from .hook import Hook
 from ..utils import master_only
-class CheckpointSaverHook(Hook):
+class CheckpointHook(Hook):
     def __init__(self,
                  interval=-1,
...

@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
 from .hook import Hook
-class OptimizerStepperHook(Hook):
+class OptimizerHook(Hook):
     def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
         self.grad_clip = grad_clip
...
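
The body of OptimizerHook is truncated here; the clip_grad import and the (grad_clip, max_norm, norm_type) arguments suggest it clips gradient norms before stepping the optimizer. A hedged sketch of that clip-then-step pattern, not the actual hook implementation:

from torch.nn.utils import clip_grad

def clip_and_step(model, optimizer, loss, max_norm=35, norm_type=2):
    # Clip the global gradient norm of all trainable parameters, then step;
    # the hook presumably wires this into the runner's train iteration.
    optimizer.zero_grad()
    loss.backward()
    clip_grad.clip_grad_norm_(
        filter(lambda p: p.requires_grad, model.parameters()),
        max_norm=max_norm, norm_type=norm_type)
    optimizer.step()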

@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
 from .log_buffer import LogBuffer
 from .. import hooks
-from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
-                     OptimizerStepperHook)
+from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
+                     OptimizerHook)
 from ..io import load_checkpoint, save_checkpoint
 from ..utils import (get_dist_info, get_host_info, get_time_str,
                      add_file_handler, obj_from_dict)
@@ -182,6 +182,16 @@ class Runner(object):
         if not inserted:
             self._hooks.insert(0, hook)
+    def build_hook(self, hook, args):
+        assert issubclass(hook, Hook), '"hook" must be a Hook subclass'
+        if isinstance(args, dict):
+            self.register_hook(hook(**args))
+        elif isinstance(args, Hook):
+            self.register_hook(args)
+        else:
+            raise TypeError('"args" must be either a Hook object'
+                            ' or dict, not {}'.format(type(args)))
     def call_hook(self, fn_name):
         for hook in self._hooks:
             getattr(hook, fn_name)(self)
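
The new build_hook accepts either a config dict, which is expanded into the given hook class's constructor, or an already-built Hook instance, which is registered as-is (the class argument is then unused, though it must still be a class for the issubclass assert to pass). Hypothetical usage, with illustrative argument values:

# From a config dict: instantiates OptimizerHook(**args) and registers it.
runner.build_hook(OptimizerHook, dict(grad_clip=True, max_norm=20))

# From a ready-made instance: registered unchanged.
runner.build_hook(OptimizerHook, OptimizerHook(grad_clip=False))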

@@ -329,7 +339,7 @@ class Runner(object):
     def register_training_hooks(self,
                                 lr_config,
-                                grad_clip_config=None,
+                                optimizer_config=None,
                                 checkpoint_config=None,
                                 log_config=None):
         """Register default hooks for training.
@@ -341,22 +351,13 @@ class Runner(object):
         - IterTimerHook
         - LoggerHook
         """
-        if grad_clip_config is None:
-            grad_clip_config = {}
+        if optimizer_config is None:
+            optimizer_config = {}
         if checkpoint_config is None:
             checkpoint_config = {}
         self.register_lr_hooks(lr_config)
-        if isinstance(grad_clip_config, Hook):
-            self.register_hook(grad_clip_config)
-        elif isinstance(grad_clip_config, dict):
-            self.register_hook(OptimizerStepperHook(**grad_clip_config))
-        else:
-            raise TypeError(
-                "OptimizerStepperHook should be a Hook object or dict, not {}".
-                format(type(grad_clip_config)))
-        self.register_hook(CheckpointSaverHook(**checkpoint_config))
+        self.build_hook(OptimizerHook, optimizer_config)
+        self.build_hook(CheckpointHook, checkpoint_config)
         self.register_hook(IterTimerHook())
         if log_config is not None:
             self.register_logger_hooks(log_config)
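
With this refactor, register_training_hooks renames grad_clip_config to optimizer_config and routes both the optimizer and checkpoint configs through build_hook, so each may be a dict or a ready-made Hook instance. A hypothetical call; the lr_config contents are an assumption about register_lr_hooks, which this diff does not show:

runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),         # assumed format
    optimizer_config=dict(grad_clip=True, max_norm=35),  # OptimizerHook args
    checkpoint_config=dict(interval=1))                  # CheckpointHook arg from the diff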