Commit fc5319b6 authored by pangjm's avatar pangjm
Browse files

add build_hook & other minor modify

parent ee54f9cf
......@@ -3,7 +3,7 @@ from __future__ import division
import cv2
import numpy as np
__all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'bbox_flip']
__all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple']
def imflip(img, direction='horizontal'):
......@@ -111,20 +111,6 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
return scaled_bboxes
def bbox_flip(bboxes, img_shape):
"""Flip bboxes horizontally
Args:
bboxes(ndarray): shape (..., 4*k)
img_shape(tuple): (height, width)
"""
assert bboxes.shape[-1] % 4 == 0
w = img_shape[1]
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
return flipped
def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
"""Crop image patches.
......
from .hook import Hook
from .checkpoint_saver import CheckpointSaverHook
from .checkpoint_saver import CheckpointHook
from .closure import ClosureHook
from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerStepperHook
from .optimizer_stepper import OptimizerHook
from .iter_timer import IterTimerHook
from .logger import *
......@@ -2,7 +2,7 @@ from .hook import Hook
from ..utils import master_only
class CheckpointSaverHook(Hook):
class CheckpointHook(Hook):
def __init__(self,
interval=-1,
......
......@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
from .hook import Hook
class OptimizerStepperHook(Hook):
class OptimizerHook(Hook):
def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
self.grad_clip = grad_clip
......
......@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
from .log_buffer import LogBuffer
from .. import hooks
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
OptimizerStepperHook)
from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
OptimizerHook)
from ..io import load_checkpoint, save_checkpoint
from ..utils import (get_dist_info, get_host_info, get_time_str,
add_file_handler, obj_from_dict)
......@@ -182,6 +182,16 @@ class Runner(object):
if not inserted:
self._hooks.insert(0, hook)
def build_hook(self, hook, args):
assert issubclass(hook, Hook), '"hook" must be a Hook object'
if isinstance(args, dict):
self.register_hook(hook(**args))
elif isinstance(args, Hook):
self.register_hook(args)
else:
raise TypeError('"args" must be either a Hook object'
' or dict, not {}'.format(type(args)))
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
......@@ -329,7 +339,7 @@ class Runner(object):
def register_training_hooks(self,
lr_config,
grad_clip_config=None,
optimizer_config=None,
checkpoint_config=None,
log_config=None):
"""Register default hooks for training.
......@@ -341,22 +351,13 @@ class Runner(object):
- IterTimerHook
- LoggerHook
"""
if grad_clip_config is None:
grad_clip_config = {}
if optimizer_config is None:
optimizer_config = {}
if checkpoint_config is None:
checkpoint_config = {}
self.register_lr_hooks(lr_config)
if isinstance(grad_clip_config, Hook):
self.register_hook(grad_clip_config)
elif isinstance(grad_clip_config, dict):
self.register_hook(OptimizerStepperHook(**grad_clip_config))
else:
raise TypeError(
"OptimizerStepperHook should be a Hook object or dict, not {}".
format(type(grad_clip_config)))
self.register_hook(CheckpointSaverHook(**checkpoint_config))
self.build_hook(OptimizerHook, optimizer_config)
self.build_hook(CheckpointHook, checkpoint_config)
self.register_hook(IterTimerHook())
if log_config is not None:
self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment