Commit 523f861b authored by Kai Chen

add priority enum

parent f573c11c
__init__.py
@@ -7,6 +7,7 @@ from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
 from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
 from .parallel import parallel_test, worker_func
+from .priority import Priority, get_priority
 from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
                     obj_from_dict)
@@ -15,6 +16,7 @@ __all__ = [
     'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
-    'parallel_test', 'worker_func', 'get_host_info', 'get_dist_info',
-    'master_only', 'get_time_str', 'obj_from_dict'
+    'parallel_test', 'worker_func', 'Priority', 'get_priority',
+    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
+    'obj_from_dict'
 ]
priority.py (new file)
from enum import Enum


class Priority(Enum):
    HIGHEST = 0
    VERY_HIGH = 20
    HIGH = 40
    NORMAL = 50
    LOW = 60
    VERY_LOW = 80
    LOWEST = 100


def get_priority(priority):
    """Get priority value.

    Args:
        priority (int or str or :obj:`Priority`): Priority.

    Returns:
        int: The priority value.
    """
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    elif isinstance(priority, Priority):
        return priority.value
    elif isinstance(priority, str):
        return Priority[priority.upper()].value
    else:
        raise TypeError('priority must be an integer or Priority enum value')
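For reference, a minimal usage sketch of the new helper; it assumes the module is re-exported from the runner package (written here as mmcv.runner, per the __init__.py change above) so that all three accepted forms resolve to the same integer:

    from mmcv.runner import Priority, get_priority

    assert get_priority(50) == 50                # plain int, must fall in [0, 100]
    assert get_priority('normal') == 50          # case-insensitive enum name lookup
    assert get_priority(Priority.NORMAL) == 50   # enum member, its .value is returned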
runner.py
@@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
 from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
                     OptimizerHook, lr_updater)
 from .checkpoint import load_checkpoint, save_checkpoint
+from .priority import get_priority
 from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict


 class Runner(object):
-    """A training helper for PyTorch."""
+    """A training helper for PyTorch.
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that process a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`
+        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
+            runner will construct an optimizer according to it.
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs.
+        log_level (int): Logging level.
+    """

     def __init__(self,
                  model,
@@ -154,8 +167,8 @@ class Runner(object):
         logging.basicConfig(
             format='%(asctime)s - %(levelname)s - %(message)s', level=level)
         logger = logging.getLogger(__name__)
-        if log_dir:
-            filename = '{}_{}.log'.format(get_time_str(), self.rank)
+        if log_dir and self.rank == 0:
+            filename = '{}.log'.format(get_time_str())
             log_file = osp.join(log_dir, filename)
             self._add_file_handler(logger, log_file, level=level)
         return logger
@@ -171,17 +184,18 @@ class Runner(object):
                 'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]

-    def register_hook(self, hook, priority=50):
+    def register_hook(self, hook, priority='NORMAL'):
         """Register a hook into the hook list.

         Args:
             hook (:obj:`Hook`): The hook to be registered.
-            priority (int): Hook priority. Lower value means higher priority.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
         """
         assert isinstance(hook, Hook)
-        assert isinstance(priority, int) and priority >= 0 and priority <= 100
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
         hook.priority = priority
         # insert the hook to a sorted list
         inserted = False
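With get_priority wired in, hooks can now be registered by name rather than by bare integer. A small hypothetical sketch, where runner is an existing Runner instance and my_hook, my_hook2, my_hook3 are placeholder Hook subclass instances (none of them defined in this commit):

    runner.register_hook(my_hook, priority='VERY_HIGH')    # same as 20 or Priority.VERY_HIGH
    runner.register_hook(my_hook2, priority=Priority.LOW)  # enum members resolve to 60
    runner.register_hook(my_hook3, priority=55)             # plain integers keep working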
@@ -292,6 +306,17 @@ class Runner(object):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

     def run(self, data_loaders, workflow, max_epochs, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g, [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+            max_epochs (int): Total training epochs.
+        """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
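As a rough usage sketch of the workflow argument documented above (the runner, the data loaders, and the epoch count are placeholders, not part of this commit):

    # train_loader / val_loader are assumed to be existing torch DataLoader objects
    runner.run([train_loader, val_loader],
               workflow=[('train', 2), ('val', 1)],
               max_epochs=12)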
@@ -346,7 +371,7 @@ class Runner(object):
         for info in log_config['hooks']:
             logger_hook = obj_from_dict(
                 info, hooks, default_args=dict(interval=log_interval))
-            self.register_hook(logger_hook, priority=60)
+            self.register_hook(logger_hook, priority='VERY_LOW')

     def register_training_hooks(self,
                                 lr_config,
@@ -360,7 +385,7 @@ class Runner(object):
         - OptimizerStepperHook
         - CheckpointSaverHook
         - IterTimerHook
-        - LoggerHook
+        - LoggerHook(s)
         """
         if optimizer_config is None:
             optimizer_config = {}