Commit ffdc1d45 authored by Kai Chen

add initial version of torchpack

parent 02ceae83
import logging
import os.path as osp
import time

import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .log_buffer import LogBuffer
from .. import hooks
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
                     OptimizerStepperHook)
from ..io import load_checkpoint, save_checkpoint
from ..utils import (get_dist_info, get_host_info, get_time_str,
                     add_file_handler, obj_from_dict)
class Runner(object):
    """A training helper for PyTorch."""

    def __init__(self,
                 model,
                 optimizer,
                 batch_processor,
                 work_dir=None,
                 log_level=logging.INFO):
        assert callable(batch_processor)
        self.model = model
        self.optimizer = self.init_optimizer(optimizer)
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if isinstance(self.model, (DataParallel, DistributedDataParallel)):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.logger = self.init_logger(work_dir, log_level)
        self.log_buffer = LogBuffer()

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process (distributed training)."""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job
        (distributed training)."""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters
    def init_optimizer(self, optimizer):
        """Init the optimizer.

        Args:
            optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
                optimizer object or a dict used for constructing the
                optimizer. An example of the dict: ``{'type': 'SGD',
                'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}``.

        Returns:
            :obj:`~torch.optim.Optimizer`: An optimizer object.
        """
        if isinstance(optimizer, dict):
            optimizer = obj_from_dict(
                optimizer, torch.optim, dict(params=self.model.parameters()))
        elif not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                'optimizer must be either an Optimizer object or a dict, '
                'but got {}'.format(type(optimizer)))
        return optimizer
    def init_logger(self, log_dir=None, level=logging.INFO):
        """Init the logger.

        Args:
            log_dir (str, optional): Log file directory. If not specified, no
                log file will be used.
            level (int or str): See the built-in python logging module.

        Returns:
            :obj:`~logging.Logger`: Python logger.
        """
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s', level=level)
        logger = logging.getLogger(__name__)
        if log_dir:
            filename = '{}_{}.log'.format(get_time_str(), self.rank)
            log_file = osp.join(log_dir, filename)
            add_file_handler(logger, log_file, level=level)
        return logger
    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rates of all param groups.
        """
        return [group['lr'] for group in self.optimizer.param_groups]
    def register_hook(self, hook, priority=50):
        """Register a hook into the hook list.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int): Hook priority. Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        assert isinstance(priority, int) and 0 <= priority <= 100
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)
    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
        linkname = osp.join(out_dir, 'latest.pth')
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
        mmcv.symlink(filename, linkname)
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')
    def resume(self, checkpoint, resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)
        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # e.g. 'train' -> self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an '
                            'epoch'.format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom epoch runner
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
    def register_lr_hooks(self, lr_config):
        if isinstance(lr_config, LrUpdaterHook):
            self.register_hook(lr_config)
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            from ..hooks import lr_updater
            hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
            if not hasattr(lr_updater, hook_name):
                raise ValueError('"{}" does not exist'.format(hook_name))
            hook_cls = getattr(lr_updater, hook_name)
            self.register_hook(hook_cls(**lr_config))
        else:
            raise TypeError('"lr_config" must be either a LrUpdaterHook '
                            'object or dict, not {}'.format(type(lr_config)))
    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = obj_from_dict(
                info, hooks, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority=60)

    def register_default_hooks(self,
                               lr_config,
                               grad_clip_config=None,
                               checkpoint_config=None,
                               log_config=None):
        """Register several default hooks.

        Default hooks include:

        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook
        """
        if grad_clip_config is None:
            grad_clip_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(OptimizerStepperHook(**grad_clip_config))
        self.register_hook(CheckpointSaverHook(**checkpoint_config))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)
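
# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the commit): how a Runner is
# typically driven. The model, the data loaders, the 'step' LR policy and the
# 'TextLoggerHook' logger name are assumptions for illustration only; the
# Runner API calls themselves come from the class above.
# ---------------------------------------------------------------------------
def _example_usage(model, train_loader, val_loader):

    def batch_processor(model, data_batch, train_mode):
        # assumed: the model returns a scalar loss tensor for a data batch
        loss = model(data_batch)
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=len(data_batch))

    runner = Runner(
        model,
        # the dict form is resolved by init_optimizer() via obj_from_dict()
        optimizer=dict(type='SGD', lr=0.02, momentum=0.9,
                       weight_decay=0.0001),
        batch_processor=batch_processor,
        work_dir='./work_dir')
    runner.register_default_hooks(
        lr_config=dict(policy='step', step=[8, 11]),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
    # alternate one training epoch and one validation epoch for 12 epochs
    runner.run([train_loader, val_loader], [('train', 1), ('val', 1)],
               max_epochs=12)


# ---------------------------------------------------------------------------
# The remainder of the commit adds the runner utilities imported above
# (get_host_info, get_dist_info, get_time_str, add_file_handler,
# obj_from_dict), which live in a separate module.
# ---------------------------------------------------------------------------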
import functools
import logging
import time
from getpass import getuser
from socket import gethostname
import mmcv
import torch.distributed as dist
def get_host_info():
    return '{}@{}'.format(getuser(), gethostname())


def get_dist_info():
    if dist._initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
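
# Usage sketch (editor's addition): master_only turns a function into a no-op
# on every process except rank 0, which is useful for logging or file I/O in
# distributed jobs. The decorated function below is illustrative only.
@master_only
def _print_on_master(msg):
    # executed only on the rank-0 (master) process; other ranks return None
    print(msg)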
def get_time_str():
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def add_file_handler(logger, filename=None, mode='w', level=logging.INFO):
    file_handler = logging.FileHandler(filename, mode)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    file_handler.setLevel(level)  # apply the requested level to the handler
    logger.addHandler(file_handler)
    return logger
def obj_from_dict(info, module, default_args=None):
    """Initialize an object from a dict.

    The dict must contain the key "type", which indicates the object type. It
    can be either a string or a type, such as "list" or ``list``. The
    remaining fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object type and arguments.
        module (:class:`module`): Module which may contain the expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.

    Returns:
        any type: Object built from the dict.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_type = getattr(module, obj_type)
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
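
# Usage sketch (editor's addition): this mirrors what Runner.init_optimizer()
# does with a dict config: look up args['type'] in the given module and
# construct it with the remaining keys plus any default_args. The tiny linear
# model is an assumption for illustration.
if __name__ == '__main__':
    import torch.nn as nn
    import torch.optim

    model = nn.Linear(4, 2)
    optimizer = obj_from_dict(
        dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
        torch.optim,
        default_args=dict(params=model.parameters()))
    print(optimizer)  # an SGD optimizer over the model's parameters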