Commit d0fb2a8d authored by Kai Chen

suppress logging for processes whose rank > 0

parent 5421859a
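The whole change hinges on one pattern: route status messages through Python's logging module and, on distributed runs, raise the log level of every process whose rank is not 0, so that messages such as 'load model from: ...' appear exactly once instead of once per GPU. A minimal sketch of that pattern, assuming torch.distributed has already been initialized; the helper name here is illustrative, not part of the patch:

import logging

import torch.distributed as dist


def get_root_logger(log_level=logging.INFO):
    # Configure the root logger once; anything that later calls
    # logging.getLogger() (e.g. checkpoint loading) inherits this setup.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    logger = logging.getLogger()
    # Only rank 0 keeps INFO output; the other ranks are raised to ERROR,
    # which is the suppression this commit implements in the training script.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        logger.setLevel('ERROR')
    return logger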
+import logging
 import math
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.torchpack import load_checkpoint
@@ -241,7 +243,8 @@ class ResNet(nn.Module):

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
-            load_checkpoint(self, pretrained, strict=False)
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
...
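Passing the logger into load_checkpoint means the messages produced while loading a pretrained backbone (for example warnings about missing or unexpected state-dict keys) go through the logging module rather than stdout; assuming mmcv's load_checkpoint forwards its output to the given logger, as the new logger= argument suggests, those messages are likewise silenced on ranks whose level has been raised to ERROR.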
+import logging
 from abc import ABCMeta, abstractmethod

 import torch
@@ -12,10 +13,6 @@ class BaseDetector(nn.Module):
     def __init__(self):
         super(BaseDetector, self).__init__()

-    @abstractmethod
-    def init_weights(self):
-        pass
-
     @abstractmethod
     def extract_feat(self, imgs):
         pass
@@ -39,6 +36,11 @@ class BaseDetector(nn.Module):
     def aug_test(self, imgs, img_metas, **kwargs):
         pass

+    def init_weights(self, pretrained=None):
+        if pretrained is not None:
+            logger = logging.getLogger()
+            logger.info('load model from: {}'.format(pretrained))
+
     def forward_test(self, imgs, img_metas, **kwargs):
         for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
             if not isinstance(var, list):
...
@@ -24,8 +24,7 @@ class RPN(BaseDetector, RPNTestMixin):
         self.init_weights(pretrained=pretrained)

     def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print('load model from: {}'.format(pretrained))
+        super(RPN, self).init_weights(pretrained)
         self.backbone.init_weights(pretrained=pretrained)
         if self.neck is not None:
             self.neck.init_weights()
...
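With init_weights now implemented on BaseDetector, the 'load model from: ...' message is emitted in exactly one place and each detector simply delegates to it before initializing its own modules. A condensed sketch of the delegation pattern (class and attribute names abbreviated from the hunks above; ToyBackbone is a placeholder):

import logging

import torch.nn as nn


class BaseDetector(nn.Module):
    def init_weights(self, pretrained=None):
        if pretrained is not None:
            # Goes through logging, so non-zero ranks can suppress it.
            logging.getLogger().info('load model from: {}'.format(pretrained))


class ToyBackbone(nn.Module):  # stand-in for ResNet in this sketch
    def init_weights(self, pretrained=None):
        pass  # real backbones call load_checkpoint(..., logger=logger) here


class RPN(BaseDetector):
    def __init__(self):
        super(RPN, self).__init__()
        self.backbone = ToyBackbone()

    def init_weights(self, pretrained=None):
        # Log once via the base class, then initialize submodules as before.
        super(RPN, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)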
@@ -24,10 +24,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         super(TwoStageDetector, self).__init__()
         self.backbone = builder.build_backbone(backbone)

-        self.with_neck = True if neck is not None else False
-        assert self.with_neck, "TwoStageDetector must be implemented with FPN now."
-        if self.with_neck:
+        if neck is not None:
+            self.with_neck = True
             self.neck = builder.build_neck(neck)
+        else:
+            raise NotImplementedError

         self.with_rpn = True if rpn_head is not None else False
         if self.with_rpn:
@@ -51,8 +52,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         self.init_weights(pretrained=pretrained)

     def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print('load model from: {}'.format(pretrained))
+        super(TwoStageDetector, self).init_weights(pretrained)
         self.backbone.init_weights(pretrained=pretrained)
         if self.with_neck:
             if isinstance(self.neck, nn.Sequential):
@@ -104,9 +104,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             pos_gt_labels) = multi_apply(
                 self.bbox_roi_extractor.sample_proposals, proposal_list,
                 gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list)
-            labels, label_weights, bbox_targets, bbox_weights = \
-                self.bbox_head.get_bbox_target(pos_proposals, neg_proposals,
-                                               pos_gt_bboxes, pos_gt_labels, self.train_cfg.rcnn)
+            (labels, label_weights, bbox_targets,
+             bbox_weights) = self.bbox_head.get_bbox_target(
+                 pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
+                 self.train_cfg.rcnn)

             rois = bbox2roi([
                 torch.cat([pos, neg], dim=0)
@@ -139,7 +140,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,

     def simple_test(self, img, img_meta, proposals=None, rescale=False):
         """Test without augmentation."""
-        assert proposals == None, "Fast RCNN hasn't been implemented."
+        assert proposals is None, "Fast RCNN hasn't been implemented."
         assert self.with_bbox, "Bbox head must be implemented."
         x = self.extract_feat(img)
@@ -152,12 +153,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         bbox_results = bbox2result(det_bboxes, det_labels,
                                    self.bbox_head.num_classes)

-        if self.with_mask:
+        if not self.with_mask:
+            return bbox_results
+        else:
             segm_results = self.simple_test_mask(
                 x, img_meta, det_bboxes, det_labels, rescale=rescale)
             return bbox_results, segm_results
-        else:
-            return bbox_results

     def aug_test(self, imgs, img_metas, rescale=False):
         """Test with augmentations.
@@ -165,7 +166,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         If rescale is False, then returned bboxes and masks will fit the scale
         of imgs[0].
         """
-        # recompute self.extract_feats(imgs) because of 'yield' and memory
+        # recompute feats to save memory
         proposal_list = self.aug_test_rpn(
             self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
         det_bboxes, det_labels = self.aug_test_bboxes(
@@ -183,10 +184,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         # det_bboxes always keep the original scale
         if self.with_mask:
             segm_results = self.aug_test_mask(
-                self.extract_feats(imgs),
-                img_metas,
-                det_bboxes,
-                det_labels)
+                self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
             return bbox_results, segm_results
         else:
             return bbox_results
@@ -114,4 +114,4 @@ log_level = 'INFO'
 work_dir = './work_dirs/fpn_rpn_r50_1x'
 load_from = None
 resume_from = None
-workflow = [('train', 1), ('val', 1)]
+workflow = [('train', 1)]
@@ -2,4 +2,4 @@
 PYTHON=${PYTHON:-"python"}

-$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch $3
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch ${@:3}
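The shell change is unrelated to logging but worth noting: ${@:3} forwards every argument from the third one onward to train.py (so flags such as --validate or the new --work_dir can simply be appended to the launch command), whereas $3 forwarded only a single extra argument.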
 from __future__ import division
 import argparse
+import logging
 from collections import OrderedDict

 import torch
@@ -45,9 +46,17 @@ def batch_processor(model, data, train_mode):
     return outputs


+def get_logger(log_level):
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
+    logger = logging.getLogger()
+    return logger
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
     parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work_dir', help='the dir to save logs and models')
     parser.add_argument(
         '--validate',
         action='store_true',
@@ -69,16 +78,22 @@ def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)
-    cfg.update(gpus=args.gpus)
+    if args.work_dir is not None:
+        cfg.work_dir = args.work_dir
+    cfg.gpus = args.gpus
+    logger = get_logger(cfg.log_level)

     # init distributed environment if necessary
     if args.launcher == 'none':
         dist = False
-        print('Disabled distributed training.')
+        logger.info('Disabled distributed training.')
     else:
         dist = True
-        print('Enabled distributed training.')
         init_dist(args.launcher, **cfg.dist_params)
+        if torch.distributed.get_rank() != 0:
+            logger.setLevel('ERROR')
+        logger.info('Enabled distributed training.')

     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
...
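Putting the training-script pieces together: get_logger configures the root logger from cfg.log_level, and after init_dist every process with a non-zero rank raises its level to ERROR, so 'Enabled distributed training.' and all later INFO messages (including the 'load model from: ...' message above) are printed only by rank 0. A standalone illustration of why raising the level is enough, with purely hypothetical messages and no distributed setup required:

import logging

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level='INFO')
logger = logging.getLogger()

logger.info('visible everywhere before the level is raised')

# What a process with rank != 0 does right after init_dist:
logger.setLevel('ERROR')
logger.info('suppressed: INFO is below the ERROR threshold')
logger.error('still shown: errors are never silenced')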