Unverified commit b34a8432 authored by Kai Chen, committed by GitHub

Init logger before constructing Runner (#1865)

* init logger before constructing Runner

* use mmdet logger for loading checkpoints

* bug fix for abstract methods
parent 141b6c98
@@ -24,15 +24,24 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)


-def get_root_logger(log_level=logging.INFO):
-    logger = logging.getLogger()
-    if not logger.hasHandlers():
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s',
-            level=log_level)
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    logger = logging.getLogger('mmdet')
+    # if the logger has been initialized, just return it
+    if logger.hasHandlers():
+        return logger
+
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
     rank, _ = get_dist_info()
     if rank != 0:
         logger.setLevel('ERROR')
+    elif log_file is not None:
+        file_handler = logging.FileHandler(log_file, 'w')
+        file_handler.setFormatter(
+            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+
     return logger
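For context, a minimal sketch of how the reworked `get_root_logger` is meant to behave (the log path below is only an example): the first call configures a dedicated 'mmdet' logger and, on rank 0, attaches a file handler; any later call hits the `hasHandlers()` early return and simply reuses that logger, which is why downstream code such as `train_detector` can call it again without re-passing the file path.

```python
import logging

from mmdet.apis import get_root_logger

# First call (e.g. in tools/train.py): configures the 'mmdet' logger and,
# on rank 0, attaches a FileHandler writing to the given path.
logger = get_root_logger(log_file='work_dirs/example/20200101_120000.log',
                         log_level=logging.INFO)
logger.info('logger initialized')

# Any later call returns the already-configured logger unchanged,
# so handlers are never duplicated.
assert get_root_logger() is logger
```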
@@ -75,15 +84,26 @@ def train_detector(model,
                    cfg,
                    distributed=False,
                    validate=False,
-                   logger=None):
-    if logger is None:
-        logger = get_root_logger(cfg.log_level)
+                   timestamp=None):
+    logger = get_root_logger(cfg.log_level)

     # start training
     if distributed:
-        _dist_train(model, dataset, cfg, validate=validate)
+        _dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)
     else:
-        _non_dist_train(model, dataset, cfg, validate=validate)
+        _non_dist_train(
+            model,
+            dataset,
+            cfg,
+            validate=validate,
+            logger=logger,
+            timestamp=timestamp)


 def build_optimizer(model, optimizer_cfg):
@@ -166,7 +186,12 @@ def build_optimizer(model, optimizer_cfg):
     return optimizer_cls(params, **optimizer_cfg)


-def _dist_train(model, dataset, cfg, validate=False):
+def _dist_train(model,
+                dataset,
+                cfg,
+                validate=False,
+                logger=None,
+                timestamp=None):
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
@@ -179,8 +204,10 @@ def _dist_train(model, dataset, cfg, validate=False):
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
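The `runner.timestamp = timestamp` assignment is, as the comment says, a workaround rather than a documented mmcv API: the plain-text log created in tools/train.py and the JSON log written by the runner's logger hooks are named from the same timestamp string, so the two files end up as a matching pair in work_dir. A standalone sketch of the naming relationship (directory and values are placeholders; the claim that the JSON log is named after `runner.timestamp` reflects mmcv's text logger hook as I understand it):

```python
import os.path as osp
import time

# Placeholder: in tools/train.py this comes from cfg.work_dir.
work_dir = 'work_dirs/example'

timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

# Plain-text log written through the handler configured by get_root_logger(...)
log_file = osp.join(work_dir, '{}.log'.format(timestamp))

# With runner.timestamp set to the same string, the runner's logger hooks
# produce a JSON log with a matching stem next to it.
json_log = osp.join(work_dir, '{}.log.json'.format(timestamp))

print(log_file)   # e.g. work_dirs/example/20200101_120000.log
print(json_log)   # e.g. work_dirs/example/20200101_120000.log.json
```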
@@ -218,7 +245,12 @@ def _dist_train(model, dataset, cfg, validate=False):
     runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


-def _non_dist_train(model, dataset, cfg, validate=False):
+def _non_dist_train(model,
+                    dataset,
+                    cfg,
+                    validate=False,
+                    logger=None,
+                    timestamp=None):
     if validate:
         raise NotImplementedError('Built-in validation is not implemented '
                                   'yet in not-distributed training. Use '
@@ -239,8 +271,10 @@ def _non_dist_train(model, dataset, cfg, validate=False):
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+    runner = Runner(
+        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
...
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -462,7 +460,8 @@ class HRNet(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
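The same three-line change is applied to each backbone below: `load_checkpoint` accepts an optional logger for its messages about missing or unexpected keys, and routing them through the shared 'mmdet' logger puts them in the same console output and .log file as the rest of training. A small sketch of the pattern outside a backbone (the model and checkpoint path are placeholders):

```python
import torch.nn as nn

from mmcv.runner import load_checkpoint
from mmdet.apis import get_root_logger

model = nn.Linear(4, 2)      # stand-in for a real backbone
logger = get_root_logger()   # the shared 'mmdet' logger

# Messages about missing/unexpected keys emitted by load_checkpoint are now
# routed through the mmdet logger (console + optional log file) instead of
# the unconfigured root logger.  The checkpoint path below is a placeholder.
load_checkpoint(model, 'checkpoints/backbone.pth', strict=False, logger=logger)
```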
...
-import logging
-
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import constant_init, kaiming_init
@@ -495,7 +493,8 @@ class ResNet(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
...
-import logging
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -75,7 +73,8 @@ class SSDVGG(VGG):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.features.modules():
...
-import logging
 from abc import ABCMeta, abstractmethod

 import mmcv
@@ -9,11 +8,9 @@ import torch.nn as nn
 from mmdet.core import auto_fp16, get_classes, tensor2imgs


-class BaseDetector(nn.Module):
+class BaseDetector(nn.Module, metaclass=ABCMeta):
     """Base class for detectors"""

-    __metaclass__ = ABCMeta
-
     def __init__(self):
         super(BaseDetector, self).__init__()
         self.fp16_enabled = False
@@ -61,9 +58,8 @@ class BaseDetector(nn.Module):
         """
         pass

-    @abstractmethod
     async def async_simple_test(self, img, img_meta, **kwargs):
-        pass
+        raise NotImplementedError

     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
@@ -75,7 +71,8 @@ class BaseDetector(nn.Module):

     def init_weights(self, pretrained=None):
         if pretrained is not None:
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             logger.info('load model from: {}'.format(pretrained))

     async def aforward_test(self, *, img, img_meta, **kwargs):
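This is the "bug fix for abstract methods" from the commit message: `__metaclass__ = ABCMeta` is Python 2 syntax and is silently ignored under Python 3, so the `@abstractmethod` markers on `BaseDetector` were never enforced. Declaring `metaclass=ABCMeta` in the class statement turns enforcement on, and `async_simple_test` is therefore demoted from an abstract method to a regular one that raises `NotImplementedError`, so existing detectors that only implement `simple_test` can still be instantiated. A standalone sketch of the difference (the class names here are made up, not mmdet classes):

```python
from abc import ABCMeta, abstractmethod


class OldStyle(object):
    __metaclass__ = ABCMeta  # ignored by Python 3: just an ordinary class attribute

    @abstractmethod
    def simple_test(self):
        pass


class NewStyle(metaclass=ABCMeta):
    @abstractmethod
    def simple_test(self):
        pass


OldStyle()        # instantiates fine: abstractness was never enforced
try:
    NewStyle()    # TypeError: can't instantiate abstract class NewStyle ...
except TypeError as exc:
    print(exc)
```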
...
-import logging
-
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
@@ -47,7 +45,8 @@ class ResLayer(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            logger = logging.getLogger()
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
...
 from __future__ import division
 import argparse
 import os
+import os.path as osp
+import time

+import mmcv
 import torch
 from mmcv import Config
 from mmcv.runner import init_dist
@@ -71,11 +74,17 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)

-    # init logger before other steps
-    logger = get_root_logger(cfg.log_level)
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # log some basic info
     logger.info('Distributed training: {}'.format(distributed))
     logger.info('MMDetection Version: {}'.format(__version__))
-    logger.info('Config: {}'.format(cfg.text))
+    logger.info('Config:\n{}'.format(cfg.text))

     # set random seeds
     if args.seed is not None:
@@ -103,7 +112,7 @@ def main():
         cfg,
         distributed=distributed,
         validate=args.validate,
-        logger=logger)
+        timestamp=timestamp)


 if __name__ == '__main__':
...
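Putting the pieces together, a condensed sketch of how a training script drives the new interface after this change (the work_dir and log level are placeholders, and the final call is shown commented out because the model, datasets, and config would come from a real setup; only the logger/timestamp handling follows the diff above):

```python
import logging
import os.path as osp
import time

import mmcv
from mmdet.apis import get_root_logger, train_detector

# Placeholder: in tools/train.py this comes from the parsed config.
work_dir = 'work_dirs/example'
mmcv.mkdir_or_exist(osp.abspath(work_dir))

# 1. Initialize the 'mmdet' logger once, before any Runner is constructed.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(work_dir, '{}.log'.format(timestamp))
logger = get_root_logger(log_file=log_file, log_level=logging.INFO)

# 2. Hand only the timestamp to train_detector; it re-fetches the already
#    configured logger internally and forwards both to the Runner.
# train_detector(model, datasets, cfg, distributed=False,
#                validate=False, timestamp=timestamp)
```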