Unverified commit 93bed07b, authored by Kai Chen and committed by GitHub

Add logger utils (#2035)

* add logger utils

* replace get_root_logger() and logger.info() with print_log()

* fix a typo

* minor fix for the format of StreamHandler
parent bc75766c

mmdet/apis/train.py

-import logging
 import random
 import re
 from collections import OrderedDict

@@ -7,35 +6,14 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info,
-                         obj_from_dict)
+from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
 from mmdet import datasets
 from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                         DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
 from mmdet.datasets import DATASETS, build_dataloader
 from mmdet.models import RPN
+from mmdet.utils import get_root_logger
-def get_root_logger(log_file=None, log_level=logging.INFO):
-    logger = logging.getLogger('mmdet')
-    # if the logger has been initialized, just return it
-    if logger.hasHandlers():
-        return logger
-    logging.basicConfig(
-        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
-    rank, _ = get_dist_info()
-    if rank != 0:
-        logger.setLevel('ERROR')
-    elif log_file is not None:
-        file_handler = logging.FileHandler(log_file, 'w')
-        file_handler.setFormatter(
-            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-        file_handler.setLevel(log_level)
-        logger.addHandler(file_handler)
-    return logger
 def set_random_seed(seed, deterministic=False):
...

mmdet/core/evaluation/mean_ap.py

-import logging
 from multiprocessing import Pool
 import mmcv
 import numpy as np
 from terminaltables import AsciiTable
+from mmdet.utils import print_log
 from .bbox_overlaps import bbox_overlaps
 from .class_names import get_classes

@@ -268,7 +268,7 @@ def eval_map(det_results,
              scale_ranges=None,
              iou_thr=0.5,
              dataset=None,
-             logger='default',
+             logger=None,
              nproc=4):
     """Evaluate mAP of a dataset.

@@ -291,11 +291,8 @@ def eval_map(det_results,
         dataset (list[str] | str | None): Dataset name or dataset classes,
             there are minor differences in metrics for different datsets, e.g.
             "voc07", "imagenet_det", etc. Default: None.
-        logger (logging.Logger | 'print' | None): The way to print the mAP
-            summary. If a Logger is specified, then the summary will be logged
-            with `logger.info()`; if set to "print", then it will be simply
-            printed to stdout; if set to None, then no information will be
-            printed. Default: 'print'.
+        logger (logging.Logger | str | None): The way to print the mAP
+            summary. See `mmdet.utils.print_log()` for details. Default: None.
         nproc (int): Processes used for computing TP and FP.
             Default: 4.

@@ -383,7 +380,7 @@ def eval_map(det_results,
         if cls_result['num_gts'] > 0:
             aps.append(cls_result['ap'])
     mean_ap = np.array(aps).mean().item() if aps else 0.0
-    if logger is not None:
-        print_map_summary(
-            mean_ap, eval_results, dataset, area_ranges, logger=logger)
+    print_map_summary(
+        mean_ap, eval_results, dataset, area_ranges, logger=logger)

@@ -405,18 +402,12 @@ def print_map_summary(mean_ap,
         results (list[dict]): Calculated from `eval_map()`.
         dataset (list[str] | str | None): Dataset name or dataset classes.
         scale_ranges (list[tuple] | None): Range of scales to be evaluated.
-        logger (logging.Logger | 'print' | None): The way to print the mAP
-            summary. If a Logger is specified, then the summary will be logged
-            with `logger.info()`; if set to "print", then it will be simply
-            printed to stdout; if set to None, then no information will be
-            printed. Default: 'print'.
+        logger (logging.Logger | str | None): The way to print the mAP
+            summary. See `mmdet.utils.print_log()` for details. Default: None.
     """
-    def _print(content):
-        if logger == 'print':
-            print(content)
-        elif isinstance(logger, logging.Logger):
-            logger.info(content)
+    if logger == 'silent':
+        return
     if isinstance(results[0]['ap'], np.ndarray):
         num_scales = len(results[0]['ap'])

@@ -426,9 +417,6 @@ def print_map_summary(mean_ap,
     if scale_ranges is not None:
         assert len(scale_ranges) == num_scales
-    assert logger is None or logger == 'print' or isinstance(
-        logger, logging.Logger)
     num_classes = len(results)
     recalls = np.zeros((num_scales, num_classes), dtype=np.float32)

@@ -453,7 +441,7 @@ def print_map_summary(mean_ap,
     header = ['class', 'gts', 'dets', 'recall', 'ap']
     for i in range(num_scales):
         if scale_ranges is not None:
-            _print('Scale range ', scale_ranges[i])
+            print_log('Scale range {}'.format(scale_ranges[i]), logger=logger)
         table_data = [header]
         for j in range(num_classes):
             row_data = [

@@ -464,4 +452,4 @@ def print_map_summary(mean_ap,
         table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])])
         table = AsciiTable(table_data)
         table.inner_footing_row_border = True
-        _print('\n' + table.table)
+        print_log('\n' + table.table, logger=logger)
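
With this change, output routing in `eval_map()` and `print_map_summary()` is controlled entirely by the `logger` argument, which is forwarded to `print_log()`: `None` prints to stdout, `'silent'` suppresses the summary (replacing the old `if logger is not None` guard), and a `logging.Logger` records it at INFO level. Below is a minimal sketch of the calling pattern; the evaluation inputs are elided, and the logger name and file name are chosen purely for illustration:

import logging

# Any ordinary Logger can receive the mAP summary; eval_map() simply
# forwards it to print_log(), which calls logger.log(INFO, msg).
eval_logger = logging.getLogger('map_eval')
eval_logger.setLevel(logging.INFO)
eval_logger.addHandler(logging.FileHandler('eval_map.log', 'w'))

# Hypothetical calls (detection results and annotations prepared elsewhere):
#   eval_map(det_results, ..., logger=eval_logger)  # summary goes to the file
#   eval_map(det_results, ..., logger=None)         # summary is printed
#   eval_map(det_results, ..., logger='silent')     # no summary output
eval_logger.info('mAP summary will be appended to eval_map.log')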

mmdet/models/backbones/hrnet.py

@@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
+from mmdet.utils import get_root_logger
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer
 from .resnet import BasicBlock, Bottleneck

@@ -460,7 +461,6 @@ class HRNet(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            from mmdet.apis import get_root_logger
             logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
...

mmdet/models/backbones/resnet.py

@@ -6,6 +6,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from mmdet.models.plugins import GeneralizedAttention
 from mmdet.ops import ContextBlock
+from mmdet.utils import get_root_logger
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer

@@ -468,7 +469,6 @@ class ResNet(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            from mmdet.apis import get_root_logger
             logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
...

mmdet/models/backbones/ssd_vgg.py

@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
 from mmcv.runner import load_checkpoint
+from mmdet.utils import get_root_logger
 from ..registry import BACKBONES

@@ -73,7 +74,6 @@ class SSDVGG(VGG):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            from mmdet.apis import get_root_logger
             logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
...

mmdet/models/detectors/base.py

@@ -6,6 +6,7 @@ import pycocotools.mask as maskUtils
 import torch.nn as nn
 from mmdet.core import auto_fp16, get_classes, tensor2imgs
+from mmdet.utils import print_log
 class BaseDetector(nn.Module, metaclass=ABCMeta):

@@ -71,9 +72,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
     def init_weights(self, pretrained=None):
         if pretrained is not None:
-            from mmdet.apis import get_root_logger
-            logger = get_root_logger()
-            logger.info('load model from: {}'.format(pretrained))
+            print_log('load model from: {}'.format(pretrained), logger='root')
     async def aforward_test(self, *, img, img_meta, **kwargs):
         for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
...

mmdet/models/shared_heads/res_layer.py

@@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 from mmdet.core import auto_fp16
+from mmdet.utils import get_root_logger
 from ..backbones import ResNet, make_res_layer
 from ..registry import SHARED_HEADS

@@ -45,7 +46,6 @@ class ResLayer(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            from mmdet.apis import get_root_logger
             logger = get_root_logger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
...

mmdet/ops/dcn/deform_conv.py

@@ -6,6 +6,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
+from mmdet.utils import print_log
 from . import deform_conv_cuda

@@ -297,10 +298,10 @@ class DeformConvPack(DeformConv):
                 '_offset.bias')
         if version is not None and version > 1:
-            from mmdet.apis import get_root_logger
-            logger = get_root_logger()
-            logger.info('DeformConvPack {} is upgraded to version 2.'.format(
-                prefix.rstrip('.')))
+            print_log(
+                'DeformConvPack {} is upgraded to version 2.'.format(
+                    prefix.rstrip('.')),
+                logger='root')
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,

@@ -420,11 +421,10 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
                 '_offset.bias')
         if version is not None and version > 1:
-            from mmdet.apis import get_root_logger
-            logger = get_root_logger()
-            logger.info(
-                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
-                    prefix.rstrip('.')))
+            print_log(
+                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
+                    prefix.rstrip('.')),
+                logger='root')
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
...

mmdet/utils/__init__.py

 from .flops_counter import get_model_complexity_info
+from .logger import get_root_logger, print_log
 from .registry import Registry, build_from_cfg

-__all__ = ['Registry', 'build_from_cfg', 'get_model_complexity_info']
+__all__ = [
+    'Registry', 'build_from_cfg', 'get_model_complexity_info',
+    'get_root_logger', 'print_log'
+]

mmdet/utils/logger.py (new file)

import logging

from mmcv.runner import get_dist_info


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmdet".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(__name__.split('.')[0])  # i.e., mmdet
    # if the logger has been initialized, just return it
    if logger.hasHandlers():
        return logger

    format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=format_str, level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    elif log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    return logger


def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used. Some
            special loggers are:
            - "root": the root logger obtained with `get_root_logger()`.
            - "silent": no message will be printed.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    if logger is None:
        print(msg)
    elif logger == 'root':
        _logger = get_root_logger()
        _logger.log(level, msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger != 'silent':
        raise TypeError(
            'logger should be either a logging.Logger object, "root", '
            '"silent" or None, but got {}'.format(logger))
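
Both helpers are re-exported from `mmdet.utils` (see the `__init__.py` change above), so downstream code can use them directly. A brief usage sketch, assuming a single-process run; the log file name and messages are illustrative only:

import logging

from mmdet.utils import get_root_logger, print_log

# Configure the 'mmdet' logger once; logging.basicConfig provides console
# output, and since log_file is given a FileHandler is attached as well.
logger = get_root_logger(log_file='train.log', log_level=logging.INFO)

print_log('plain message')                              # falls back to print()
print_log('routed to the mmdet root logger', logger='root')
print_log('explicit Logger object', logger=logger, level=logging.WARNING)
print_log('never shown', logger='silent')

In distributed runs, get_root_logger() raises the level to ERROR on every rank except 0, so messages routed through the logger stay quiet in worker processes.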

tools/train.py

@@ -10,9 +10,10 @@ from mmcv import Config
 from mmcv.runner import init_dist
 from mmdet import __version__
-from mmdet.apis import get_root_logger, set_random_seed, train_detector
+from mmdet.apis import set_random_seed, train_detector
 from mmdet.datasets import build_dataset
 from mmdet.models import build_detector
+from mmdet.utils import get_root_logger
 def parse_args():
...
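
For training scripts, only the import location changes; the call pattern stays the same. A minimal sketch of the updated usage, with a placeholder log file name that is not taken from the commit:

import logging

# get_root_logger() now comes from mmdet.utils rather than mmdet.apis.
from mmdet.utils import get_root_logger

# Typically called once at startup; rank 0 keeps log_level, while other
# ranks are raised to ERROR inside get_root_logger().
logger = get_root_logger(log_file='example_train.log', log_level=logging.INFO)
logger.info('environment and config info would be logged here')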