Unverified Commit 93bed07b authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add logger utils (#2035)

* add logger utils

* replace get_root_logger() and logger.info() with print_log()

* fix a typo

* minor fix for the format of StreamHandler
parent bc75766c
import logging
import random
import re
from collections import OrderedDict
......@@ -7,35 +6,14 @@ import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info,
obj_from_dict)
from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
from mmdet import datasets
from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
from mmdet.datasets import DATASETS, build_dataloader
from mmdet.models import RPN
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the 'mmdet' logger, initializing it on first use.

    Args:
        log_file (str | None): If given (and this is the rank-0 process),
            a FileHandler writing to this path is attached.
        log_level (int): Level used by `basicConfig` and the optional file
            handler. Non-rank-0 processes are forced to ERROR instead.

    Returns:
        logging.Logger: The 'mmdet' logger.
    """
    logger = logging.getLogger('mmdet')
    # Already configured by an earlier call: hand it back untouched.
    if logger.hasHandlers():
        return logger

    fmt = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=fmt, level=log_level)

    rank, _ = get_dist_info()
    if rank != 0:
        # Keep non-master processes quiet except for errors.
        logger.setLevel('ERROR')
    elif log_file is not None:
        fh = logging.FileHandler(log_file, 'w')
        fh.setFormatter(logging.Formatter(fmt))
        fh.setLevel(log_level)
        logger.addHandler(fh)
    return logger
from mmdet.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
......
import logging
from multiprocessing import Pool
import mmcv
import numpy as np
from terminaltables import AsciiTable
from mmdet.utils import print_log
from .bbox_overlaps import bbox_overlaps
from .class_names import get_classes
......@@ -268,7 +268,7 @@ def eval_map(det_results,
scale_ranges=None,
iou_thr=0.5,
dataset=None,
logger='default',
logger=None,
nproc=4):
"""Evaluate mAP of a dataset.
......@@ -291,11 +291,8 @@ def eval_map(det_results,
dataset (list[str] | str | None): Dataset name or dataset classes,
there are minor differences in metrics for different datasets, e.g.
"voc07", "imagenet_det", etc. Default: None.
logger (logging.Logger | 'print' | None): The way to print the mAP
summary. If a Logger is specified, then the summary will be logged
with `logger.info()`; if set to "print", then it will be simply
printed to stdout; if set to None, then no information will be
printed. Default: 'print'.
logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.
nproc (int): Processes used for computing TP and FP.
Default: 4.
......@@ -383,9 +380,9 @@ def eval_map(det_results,
if cls_result['num_gts'] > 0:
aps.append(cls_result['ap'])
mean_ap = np.array(aps).mean().item() if aps else 0.0
if logger is not None:
print_map_summary(
mean_ap, eval_results, dataset, area_ranges, logger=logger)
print_map_summary(
mean_ap, eval_results, dataset, area_ranges, logger=logger)
return mean_ap, eval_results
......@@ -405,18 +402,12 @@ def print_map_summary(mean_ap,
results (list[dict]): Calculated from `eval_map()`.
dataset (list[str] | str | None): Dataset name or dataset classes.
scale_ranges (list[tuple] | None): Range of scales to be evaluated.
logger (logging.Logger | 'print' | None): The way to print the mAP
summary. If a Logger is specified, then the summary will be logged
with `logger.info()`; if set to "print", then it will be simply
printed to stdout; if set to None, then no information will be
printed. Default: 'print'.
logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.
"""
def _print(content):
if logger == 'print':
print(content)
elif isinstance(logger, logging.Logger):
logger.info(content)
if logger == 'silent':
return
if isinstance(results[0]['ap'], np.ndarray):
num_scales = len(results[0]['ap'])
......@@ -426,9 +417,6 @@ def print_map_summary(mean_ap,
if scale_ranges is not None:
assert len(scale_ranges) == num_scales
assert logger is None or logger == 'print' or isinstance(
logger, logging.Logger)
num_classes = len(results)
recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
......@@ -453,7 +441,7 @@ def print_map_summary(mean_ap,
header = ['class', 'gts', 'dets', 'recall', 'ap']
for i in range(num_scales):
if scale_ranges is not None:
_print('Scale range ', scale_ranges[i])
print_log('Scale range {}'.format(scale_ranges[i]), logger=logger)
table_data = [header]
for j in range(num_classes):
row_data = [
......@@ -464,4 +452,4 @@ def print_map_summary(mean_ap,
table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])])
table = AsciiTable(table_data)
table.inner_footing_row_border = True
_print('\n' + table.table)
print_log('\n' + table.table, logger=logger)
......@@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import BasicBlock, Bottleneck
......@@ -460,7 +461,6 @@ class HRNet(nn.Module):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
from mmdet.apis import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
......
......@@ -6,6 +6,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.plugins import GeneralizedAttention
from mmdet.ops import ContextBlock
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
......@@ -468,7 +469,6 @@ class ResNet(nn.Module):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
from mmdet.apis import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
......
......@@ -4,6 +4,7 @@ import torch.nn.functional as F
from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
from mmcv.runner import load_checkpoint
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
......@@ -73,7 +74,6 @@ class SSDVGG(VGG):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
from mmdet.apis import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
......
......@@ -6,6 +6,7 @@ import pycocotools.mask as maskUtils
import torch.nn as nn
from mmdet.core import auto_fp16, get_classes, tensor2imgs
from mmdet.utils import print_log
class BaseDetector(nn.Module, metaclass=ABCMeta):
......@@ -71,9 +72,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
def init_weights(self, pretrained=None):
if pretrained is not None:
from mmdet.apis import get_root_logger
logger = get_root_logger()
logger.info('load model from: {}'.format(pretrained))
print_log('load model from: {}'.format(pretrained), logger='root')
async def aforward_test(self, *, img, img_meta, **kwargs):
for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
......
......@@ -3,6 +3,7 @@ from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmdet.core import auto_fp16
from mmdet.utils import get_root_logger
from ..backbones import ResNet, make_res_layer
from ..registry import SHARED_HEADS
......@@ -45,7 +46,6 @@ class ResLayer(nn.Module):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
from mmdet.apis import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
......
......@@ -6,6 +6,7 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmdet.utils import print_log
from . import deform_conv_cuda
......@@ -297,10 +298,10 @@ class DeformConvPack(DeformConv):
'_offset.bias')
if version is not None and version > 1:
from mmdet.apis import get_root_logger
logger = get_root_logger()
logger.info('DeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')))
print_log(
'DeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')),
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
......@@ -420,11 +421,10 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
'_offset.bias')
if version is not None and version > 1:
from mmdet.apis import get_root_logger
logger = get_root_logger()
logger.info(
print_log(
'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')))
prefix.rstrip('.')),
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
......
from .flops_counter import get_model_complexity_info
from .logger import get_root_logger, print_log
from .registry import Registry, build_from_cfg

# Public API of mmdet.utils. Note: exactly one __all__ assignment — the
# previous duplicate (an older list immediately overwritten by this one)
# was dead code and has been removed.
__all__ = [
    'Registry', 'build_from_cfg', 'get_model_complexity_info',
    'get_root_logger', 'print_log'
]
import logging
from mmcv.runner import get_dist_info
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger is initialized lazily: the first call attaches a
    StreamHandler (via `logging.basicConfig`), and a FileHandler as well
    when `log_file` is given. Subsequent calls return the already
    configured logger unchanged. The logger is named after the top-level
    package ("mmdet").

    Args:
        log_file (str | None): The log filename. If specified, a
            FileHandler will be added to the root logger.
        log_level (int): The root logger level. Note that only the process
            of rank 0 is affected, while other processes will set the
            level to "ERROR" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(__name__.split('.')[0])  # i.e., mmdet
    # Initialized on an earlier call: return it as-is.
    if logger.hasHandlers():
        return logger

    format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=format_str, level=log_level)

    rank, _ = get_dist_info()
    if rank != 0:
        # Silence every process except rank 0 (errors still show).
        logger.setLevel('ERROR')
        return logger

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    return logger
def print_log(msg, logger=None, level=logging.INFO):
    """Log a message through the requested channel.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): Where the message goes.
            Special values:
            - "root": the package root logger obtained with
              `get_root_logger()`.
            - "silent": discard the message.
            - None: fall back to the built-in `print()`.
        level (int): Logging level. Only honored when `logger` is a
            Logger object or "root".

    Raises:
        TypeError: If `logger` is none of the accepted values.
    """
    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger is None:
        print(msg)
    elif logger == 'root':
        get_root_logger().log(level, msg)
    elif logger != 'silent':
        raise TypeError(
            'logger should be either a logging.Logger object, "root", '
            '"silent" or None, but got {}'.format(logger))
......@@ -10,9 +10,10 @@ from mmcv import Config
from mmcv.runner import init_dist
from mmdet import __version__
from mmdet.apis import get_root_logger, set_random_seed, train_detector
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import get_root_logger
def parse_args():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment