Commit d796c13b authored by Kai Chen's avatar Kai Chen
Browse files

rename torchpack to runner

parent 51b73ac1
...@@ -7,3 +7,4 @@ from .image import * ...@@ -7,3 +7,4 @@ from .image import *
from .video import * from .video import *
from .visualization import * from .visualization import *
from .version import __version__ from .version import __version__
# runner is not imported here, so mmcv may be used without PyTorch
from .runner import Runner, LogBuffer from .runner import Runner
from .log_buffer import LogBuffer
from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook, from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
OptimizerHook, IterTimerHook, DistSamplerSeedHook, OptimizerHook, IterTimerHook, DistSamplerSeedHook,
LoggerHook, TextLoggerHook, PaviLoggerHook, LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook) TensorboardLoggerHook)
from .io import (load_state_dict, load_checkpoint, weights_to_cpu, from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint) save_checkpoint)
from .parallel import parallel_test, worker_func from .parallel import parallel_test, worker_func
from .utils import (get_host_info, get_dist_info, master_only, get_time_str, from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
add_file_handler, obj_from_dict) obj_from_dict)
__all__ = [ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook', 'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
...@@ -15,5 +16,5 @@ __all__ = [ ...@@ -15,5 +16,5 @@ __all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
'parallel_test', 'worker_func', 'get_host_info', 'get_dist_info', 'parallel_test', 'worker_func', 'get_host_info', 'get_dist_info',
'master_only', 'get_time_str', 'add_file_handler', 'obj_from_dict' 'master_only', 'get_time_str', 'obj_from_dict'
] ]
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
......
...@@ -3,7 +3,7 @@ import multiprocessing ...@@ -3,7 +3,7 @@ import multiprocessing
import torch import torch
import mmcv import mmcv
from .io import load_checkpoint from .checkpoint import load_checkpoint
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func, def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
......
...@@ -5,13 +5,12 @@ import time ...@@ -5,13 +5,12 @@ import time
import mmcv import mmcv
import torch import torch
from . import hooks
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .. import hooks from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook, OptimizerHook)
OptimizerHook) from .io import load_checkpoint, save_checkpoint
from ..io import load_checkpoint, save_checkpoint from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
from ..utils import (get_dist_info, get_host_info, get_time_str,
add_file_handler, obj_from_dict)
class Runner(object): class Runner(object):
...@@ -128,6 +127,19 @@ class Runner(object): ...@@ -128,6 +127,19 @@ class Runner(object):
'but got {}'.format(type(optimizer))) 'but got {}'.format(type(optimizer)))
return optimizer return optimizer
def _add_file_handler(self,
logger,
filename=None,
mode='w',
level=logging.INFO):
# TODO: move this method out of runner
file_handler = logging.FileHandler(filename, mode)
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setLevel(level)
logger.addHandler(file_handler)
return logger
def init_logger(self, log_dir=None, level=logging.INFO): def init_logger(self, log_dir=None, level=logging.INFO):
"""Init the logger. """Init the logger.
...@@ -145,7 +157,7 @@ class Runner(object): ...@@ -145,7 +157,7 @@ class Runner(object):
if log_dir: if log_dir:
filename = '{}_{}.log'.format(get_time_str(), self.rank) filename = '{}_{}.log'.format(get_time_str(), self.rank)
log_file = osp.join(log_dir, filename) log_file = osp.join(log_dir, filename)
add_file_handler(logger, log_file, level=level) self._add_file_handler(logger, log_file, level=level)
return logger return logger
def current_lr(self): def current_lr(self):
......
...@@ -38,14 +38,6 @@ def get_time_str(): ...@@ -38,14 +38,6 @@ def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime()) return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def add_file_handler(logger, filename=None, mode='w', level=logging.INFO):
file_handler = logging.FileHandler(filename, mode)
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
return logger
def obj_from_dict(info, parrent=None, default_args=None): def obj_from_dict(info, parrent=None, default_args=None):
"""Initialize an object from dict. """Initialize an object from dict.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment