Commit ee54f9cf authored by pangjm's avatar pangjm
Browse files

fix checkpoint & runner bugs

parent 685e8f99
...@@ -3,7 +3,7 @@ from __future__ import division ...@@ -3,7 +3,7 @@ from __future__ import division
import cv2 import cv2
import numpy as np import numpy as np
__all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple'] __all__ = ['imflip', 'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'bbox_flip']
def imflip(img, direction='horizontal'): def imflip(img, direction='horizontal'):
...@@ -111,6 +111,20 @@ def bbox_scaling(bboxes, scale, clip_shape=None): ...@@ -111,6 +111,20 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
return scaled_bboxes return scaled_bboxes
def bbox_flip(bboxes, img_shape):
"""Flip bboxes horizontally
Args:
bboxes(ndarray): shape (..., 4*k)
img_shape(tuple): (height, width)
"""
assert bboxes.shape[-1] % 4 == 0
w = img_shape[1]
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
return flipped
def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None): def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
"""Crop image patches. """Crop image patches.
......
...@@ -201,7 +201,7 @@ class Runner(object): ...@@ -201,7 +201,7 @@ class Runner(object):
else: else:
meta.update(epoch=self.epoch + 1, iter=self.iter) meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = osp.join(out_dir, filename_tmpl.format(self.epoch)) filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
linkname = osp.join(out_dir, 'latest.pth') linkname = osp.join(out_dir, 'latest.pth')
optimizer = self.optimizer if save_optimizer else None optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta) save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
...@@ -213,7 +213,6 @@ class Runner(object): ...@@ -213,7 +213,6 @@ class Runner(object):
self.data_loader = data_loader self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader) self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch') self.call_hook('before_train_epoch')
for i, data_batch in enumerate(data_loader): for i, data_batch in enumerate(data_loader):
self._inner_iter = i self._inner_iter = i
self.call_hook('before_train_iter') self.call_hook('before_train_iter')
...@@ -347,7 +346,16 @@ class Runner(object): ...@@ -347,7 +346,16 @@ class Runner(object):
if checkpoint_config is None: if checkpoint_config is None:
checkpoint_config = {} checkpoint_config = {}
self.register_lr_hooks(lr_config) self.register_lr_hooks(lr_config)
if isinstance(grad_clip_config, Hook):
self.register_hook(grad_clip_config)
elif isinstance(grad_clip_config, dict):
self.register_hook(OptimizerStepperHook(**grad_clip_config)) self.register_hook(OptimizerStepperHook(**grad_clip_config))
else:
raise TypeError(
"OptimizerStepperHook should be a Hook object or dict, not {}".
format(type(grad_clip_config)))
self.register_hook(CheckpointSaverHook(**checkpoint_config)) self.register_hook(CheckpointSaverHook(**checkpoint_config))
self.register_hook(IterTimerHook()) self.register_hook(IterTimerHook())
if log_config is not None: if log_config is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment