Commit d0fb2a8d authored by Kai Chen's avatar Kai Chen
Browse files

suppress logging for processes whose rank > 0

parent 5421859a
import logging
import math
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.torchpack import load_checkpoint
......@@ -241,7 +243,8 @@ class ResNet(nn.Module):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
load_checkpoint(self, pretrained, strict=False)
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
......
import logging
from abc import ABCMeta, abstractmethod
import torch
......@@ -12,10 +13,6 @@ class BaseDetector(nn.Module):
def __init__(self):
super(BaseDetector, self).__init__()
@abstractmethod
def init_weights(self):
pass
@abstractmethod
def extract_feat(self, imgs):
pass
......@@ -39,6 +36,11 @@ class BaseDetector(nn.Module):
def aug_test(self, imgs, img_metas, **kwargs):
pass
def init_weights(self, pretrained=None):
if pretrained is not None:
logger = logging.getLogger()
logger.info('load model from: {}'.format(pretrained))
def forward_test(self, imgs, img_metas, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
......
......@@ -24,8 +24,7 @@ class RPN(BaseDetector, RPNTestMixin):
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
if pretrained is not None:
print('load model from: {}'.format(pretrained))
super(RPN, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.neck is not None:
self.neck.init_weights()
......
......@@ -24,10 +24,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
super(TwoStageDetector, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.with_neck = True if neck is not None else False
assert self.with_neck, "TwoStageDetector must be implemented with FPN now."
if self.with_neck:
if neck is not None:
self.with_neck = True
self.neck = builder.build_neck(neck)
else:
raise NotImplementedError
self.with_rpn = True if rpn_head is not None else False
if self.with_rpn:
......@@ -51,8 +52,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
if pretrained is not None:
print('load model from: {}'.format(pretrained))
super(TwoStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
......@@ -104,9 +104,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
pos_gt_labels) = multi_apply(
self.bbox_roi_extractor.sample_proposals, proposal_list,
gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list)
labels, label_weights, bbox_targets, bbox_weights = \
self.bbox_head.get_bbox_target(pos_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_labels, self.train_cfg.rcnn)
(labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
......@@ -139,7 +140,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
def simple_test(self, img, img_meta, proposals=None, rescale=False):
"""Test without augmentation."""
assert proposals == None, "Fast RCNN hasn't been implemented."
assert proposals is None, "Fast RCNN hasn't been implemented."
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)
......@@ -152,12 +153,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
bbox_results = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
if self.with_mask:
if not self.with_mask:
return bbox_results
else:
segm_results = self.simple_test_mask(
x, img_meta, det_bboxes, det_labels, rescale=rescale)
return bbox_results, segm_results
else:
return bbox_results
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
......@@ -165,7 +166,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
# recompute self.extract_feats(imgs) because of 'yield' and memory
# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
det_bboxes, det_labels = self.aug_test_bboxes(
......@@ -183,10 +184,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# det_bboxes always keep the original scale
if self.with_mask:
segm_results = self.aug_test_mask(
self.extract_feats(imgs),
img_metas,
det_bboxes,
det_labels)
self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
return bbox_results, segm_results
else:
return bbox_results
......@@ -114,4 +114,4 @@ log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
resume_from = None
workflow = [('train', 1), ('val', 1)]
workflow = [('train', 1)]
......@@ -2,4 +2,4 @@
PYTHON=${PYTHON:-"python"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch $3
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch ${@:3}
from __future__ import division
import argparse
import logging
from collections import OrderedDict
import torch
......@@ -45,9 +46,17 @@ def batch_processor(model, data, train_mode):
return outputs
def get_logger(log_level):
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
logger = logging.getLogger()
return logger
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work_dir', help='the dir to save logs and models')
parser.add_argument(
'--validate',
action='store_true',
......@@ -69,16 +78,22 @@ def main():
args = parse_args()
cfg = Config.fromfile(args.config)
cfg.update(gpus=args.gpus)
if args.work_dir is not None:
cfg.work_dir = args.work_dir
cfg.gpus = args.gpus
logger = get_logger(cfg.log_level)
# init distributed environment if necessary
if args.launcher == 'none':
dist = False
print('Disabled distributed training.')
logger.info('Disabled distributed training.')
else:
dist = True
print('Enabled distributed training.')
init_dist(args.launcher, **cfg.dist_params)
if torch.distributed.get_rank() != 0:
logger.setLevel('ERROR')
logger.info('Enabled distributed training.')
# prepare data loaders
train_dataset = obj_from_dict(cfg.data.train, datasets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment