Unverified commit b34a8432, authored by Kai Chen, committed by GitHub

Init logger before constructing Runner (#1865)

* init logger before constructing Runner

* use mmdet logger for loading checkpoints

* bug fix for abstract methods
parent 141b6c98
Training helpers (get_root_logger and the train_detector/runner setup):

```diff
@@ -24,15 +24,24 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)


-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    logger = logging.getLogger('mmdet')
+    # if the logger has been initialized, just return it
+    if logger.hasHandlers():
+        return logger
+
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
     rank, _ = get_dist_info()
     if rank != 0:
         logger.setLevel('ERROR')
+    elif log_file is not None:
+        file_handler = logging.FileHandler(log_file, 'w')
+        file_handler.setFormatter(
+            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
     return logger
```
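A minimal usage sketch of the reworked helper (assuming mmdet is importable; the log path is illustrative): the first call configures the shared 'mmdet' logger and, on rank 0, attaches a file handler; every later call returns that same, already-configured logger.

```python
import logging

from mmdet.apis import get_root_logger

# First call: configures the 'mmdet' logger and (on rank 0) adds a FileHandler.
logger = get_root_logger(log_file='/tmp/example.log', log_level=logging.INFO)
logger.info('logger initialized')

# Later calls return early via hasHandlers(), so everyone shares one logger.
assert get_root_logger() is logger
```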
```diff
@@ -75,15 +84,26 @@ def train_detector(model,
                    cfg,
                    distributed=False,
                    validate=False,
-                   logger=None):
-    if logger is None:
-        logger = get_root_logger(cfg.log_level)
-
+                   timestamp=None):
+    logger = get_root_logger(cfg.log_level)
+
     # start training
     if distributed:
-        _dist_train(model, dataset, cfg, validate=validate)
+        _dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
     else:
-        _non_dist_train(model, dataset, cfg, validate=validate)
+        _non_dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)


 def build_optimizer(model, optimizer_cfg):
```
```diff
@@ -166,7 +186,12 @@ def build_optimizer(model, optimizer_cfg):
     return optimizer_cls(params, **optimizer_cfg)


-def _dist_train(model, dataset, cfg, validate=False):
+def _dist_train(model,
+                dataset,
+                cfg,
+                validate=False,
+                logger=None,
+                timestamp=None):
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
```
```diff
@@ -179,8 +204,10 @@ def _dist_train(model, dataset, cfg, validate=False):

     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp

     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
```
```diff
@@ -218,7 +245,12 @@ def _dist_train(model, dataset, cfg, validate=False):
     runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


-def _non_dist_train(model, dataset, cfg, validate=False):
+def _non_dist_train(model,
+                    dataset,
+                    cfg,
+                    validate=False,
+                    logger=None,
+                    timestamp=None):
     if validate:
         raise NotImplementedError('Built-in validation is not implemented '
                                   'yet in not-distributed training. Use '
```
```diff
@@ -239,8 +271,10 @@ def _non_dist_train(model, dataset, cfg, validate=False):

     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp

     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
```
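The timestamp override only exists to keep the two log files aligned. A rough sketch of the naming relationship (assuming, as in mmcv of this era, that the runner's logger hooks derive the JSON log name from runner.timestamp, while the plain-text log is opened by the training entry script itself):

```python
import time

# One timestamp is generated in the training entry script and threaded
# through train_detector into the runner.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

text_log = '{}.log'.format(timestamp)       # opened by get_root_logger(log_file=...)
json_log = '{}.log.json'.format(timestamp)  # written by the logger hook via runner.timestamp

# Both derive from the same value, so the filenames always match, e.g.
#   20191221_120000.log  and  20191221_120000.log.json
print(text_log, json_log)
```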
HRNet backbone — the root logger is replaced by the shared mmdet logger when loading pretrained checkpoints:

```diff
-import logging
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -462,7 +460,8 @@ class HRNet(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
```
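Why the named logger matters here (stdlib-only illustration, no mmdet required): handlers and levels set on logging.getLogger('mmdet') are shared by every module asking for that name, without touching the global root logger that other libraries write to.

```python
import logging

root = logging.getLogger()                 # what the backbones grabbed before
mmdet_logger = logging.getLogger('mmdet')  # what get_root_logger() now returns

# The named logger is a singleton: every caller gets the same object,
# so checkpoint-loading messages go through the same handlers (console/file).
assert logging.getLogger('mmdet') is mmdet_logger

# Silencing it (as done on non-zero ranks) leaves the root logger alone.
mmdet_logger.setLevel('ERROR')
print(root.level, mmdet_logger.level)  # 30 (WARNING) vs 40 (ERROR)
```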
ResNet backbone — same change:

```diff
-import logging
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import constant_init, kaiming_init
@@ -495,7 +493,8 @@ class ResNet(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
```
SSDVGG backbone — same change:

```diff
-import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -75,7 +73,8 @@ class SSDVGG(VGG):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.features.modules():
```
BaseDetector — besides the logger change, the Python 2 style __metaclass__ attribute is replaced by a real metaclass declaration, and the async test hook is no longer abstract:

```diff
-import logging
 from abc import ABCMeta, abstractmethod

 import mmcv
@@ -9,11 +8,9 @@ import torch.nn as nn
 from mmdet.core import auto_fp16, get_classes, tensor2imgs


-class BaseDetector(nn.Module):
+class BaseDetector(nn.Module, metaclass=ABCMeta):
     """Base class for detectors"""

-    __metaclass__ = ABCMeta
-
     def __init__(self):
         super(BaseDetector, self).__init__()
         self.fp16_enabled = False
@@ -61,9 +58,8 @@ class BaseDetector(nn.Module):
         """
         pass

-    @abstractmethod
     async def async_simple_test(self, img, img_meta, **kwargs):
-        pass
+        raise NotImplementedError

     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
@@ -75,7 +71,8 @@ class BaseDetector(nn.Module):

     def init_weights(self, pretrained=None):
         if pretrained is not None:
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             logger.info('load model from: {}'.format(pretrained))

     async def aforward_test(self, *, img, img_meta, **kwargs):
```
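A standalone sketch of why these two fixes matter (plain Python, class names illustrative): Python 3 ignores a __metaclass__ class attribute, so abstract methods were never actually enforced; and keeping async_simple_test abstract would have forced every detector subclass to implement it, whereas raising NotImplementedError keeps the async hook optional.

```python
from abc import ABCMeta, abstractmethod


class OldStyle(object):
    """__metaclass__ is a Python 2 idiom; Python 3 silently ignores it."""
    __metaclass__ = ABCMeta

    @abstractmethod
    def simple_test(self):
        pass


class NewStyle(metaclass=ABCMeta):
    """Declaring the metaclass this way actually enforces abstract methods."""

    @abstractmethod
    def simple_test(self):
        pass

    async def async_simple_test(self):
        # optional hook: subclasses override it only if they support async tests
        raise NotImplementedError


OldStyle()      # works in Python 3: abstractness is not enforced
try:
    NewStyle()  # TypeError: abstract method 'simple_test' is not implemented
except TypeError as e:
    print(e)
```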
ResLayer — same logger change:

```diff
-import logging
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -47,7 +45,8 @@ class ResLayer(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
```
Training entry script — the work_dir is created up front, and the logger (with a timestamped log file) is initialized before anything else:

```diff
 from __future__ import division
 import argparse
 import os
+import os.path as osp
+import time
+import mmcv
 import torch
 from mmcv import Config
 from mmcv.runner import init_dist
```
```diff
@@ -71,11 +74,17 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)

-    # init logger before other steps
-    logger = get_root_logger(cfg.log_level)
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # log some basic info
     logger.info('Distributed training: {}'.format(distributed))
     logger.info('MMDetection Version: {}'.format(__version__))
-    logger.info('Config: {}'.format(cfg.text))
+    logger.info('Config:\n{}'.format(cfg.text))

     # set random seeds
     if args.seed is not None:
```
```diff
@@ -103,7 +112,7 @@ def main():
         cfg,
         distributed=distributed,
         validate=args.validate,
-        logger=logger)
+        timestamp=timestamp)


 if __name__ == '__main__':
```
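Put together, the new flow can be reproduced in a user script roughly like this (a sketch: setup_logging and the work_dir path are illustrative, not part of the API):

```python
import logging
import os.path as osp
import time

import mmcv
from mmdet.apis import get_root_logger


def setup_logging(work_dir, log_level=logging.INFO):
    """Hypothetical helper mirroring the new steps in the training script."""
    # 1. make sure work_dir exists before a FileHandler is opened inside it
    mmcv.mkdir_or_exist(osp.abspath(work_dir))
    # 2. one timestamp shared by the .log file and, via runner.timestamp, .log.json
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(work_dir, '{}.log'.format(timestamp))
    # 3. initialize the shared 'mmdet' logger before any Runner is constructed
    logger = get_root_logger(log_file=log_file, log_level=log_level)
    return logger, timestamp


logger, timestamp = setup_logging('./work_dirs/example')
logger.info('logging to a {}.log file inside the work dir'.format(timestamp))
# later: train_detector(model, datasets, cfg, distributed=distributed,
#                       validate=args.validate, timestamp=timestamp)
```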