Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
d0fb2a8d
Commit
d0fb2a8d
authored
Sep 26, 2018
by
Kai Chen
Browse files
suppress logging for processes whose rank > 0
parent
5421859a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
46 additions
and
29 deletions
+46
-29
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+4
-1
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+6
-4
mmdet/models/detectors/rpn.py
mmdet/models/detectors/rpn.py
+1
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+15
-17
tools/configs/r50_fpn_rpn_1x.py
tools/configs/r50_fpn_rpn_1x.py
+1
-1
tools/dist_train.sh
tools/dist_train.sh
+1
-1
tools/train.py
tools/train.py
+18
-3
No files found.
mmdet/models/backbones/resnet.py
View file @
d0fb2a8d
import
logging
import
math
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.torchpack
import
load_checkpoint
...
...
@@ -241,7 +243,8 @@ class ResNet(nn.Module):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
load_checkpoint
(
self
,
pretrained
,
strict
=
False
)
logger
=
logging
.
getLogger
()
load_checkpoint
(
self
,
pretrained
,
strict
=
False
,
logger
=
logger
)
elif
pretrained
is
None
:
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
...
...
mmdet/models/detectors/base.py
View file @
d0fb2a8d
import
logging
from
abc
import
ABCMeta
,
abstractmethod
import
torch
...
...
@@ -12,10 +13,6 @@ class BaseDetector(nn.Module):
def
__init__
(
self
):
super
(
BaseDetector
,
self
).
__init__
()
@
abstractmethod
def
init_weights
(
self
):
pass
@
abstractmethod
def
extract_feat
(
self
,
imgs
):
pass
...
...
@@ -39,6 +36,11 @@ class BaseDetector(nn.Module):
def
aug_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
pass
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
logger
=
logging
.
getLogger
()
logger
.
info
(
'load model from: {}'
.
format
(
pretrained
))
def
forward_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
for
var
,
name
in
[(
imgs
,
'imgs'
),
(
img_metas
,
'img_metas'
)]:
if
not
isinstance
(
var
,
list
):
...
...
mmdet/models/detectors/rpn.py
View file @
d0fb2a8d
...
...
@@ -24,8 +24,7 @@ class RPN(BaseDetector, RPNTestMixin):
self
.
init_weights
(
pretrained
=
pretrained
)
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
print
(
'load model from: {}'
.
format
(
pretrained
))
super
(
RPN
,
self
).
init_weights
(
pretrained
)
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
if
self
.
neck
is
not
None
:
self
.
neck
.
init_weights
()
...
...
mmdet/models/detectors/two_stage.py
View file @
d0fb2a8d
...
...
@@ -24,10 +24,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
super
(
TwoStageDetector
,
self
).
__init__
()
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
with_neck
=
True
if
neck
is
not
None
else
False
assert
self
.
with_neck
,
"TwoStageDetector must be implemented with FPN now."
if
self
.
with_neck
:
if
neck
is
not
None
:
self
.
with_neck
=
True
self
.
neck
=
builder
.
build_neck
(
neck
)
else
:
raise
NotImplementedError
self
.
with_rpn
=
True
if
rpn_head
is
not
None
else
False
if
self
.
with_rpn
:
...
...
@@ -51,8 +52,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self
.
init_weights
(
pretrained
=
pretrained
)
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
print
(
'load model from: {}'
.
format
(
pretrained
))
super
(
TwoStageDetector
,
self
).
init_weights
(
pretrained
)
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
if
self
.
with_neck
:
if
isinstance
(
self
.
neck
,
nn
.
Sequential
):
...
...
@@ -104,9 +104,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
pos_gt_labels
)
=
multi_apply
(
self
.
bbox_roi_extractor
.
sample_proposals
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
rcnn_train_cfg_list
)
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
\
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
=
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
...
...
@@ -139,7 +140,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
def
simple_test
(
self
,
img
,
img_meta
,
proposals
=
None
,
rescale
=
False
):
"""Test without augmentation."""
assert
proposals
==
None
,
"Fast RCNN hasn't been implemented."
assert
proposals
is
None
,
"Fast RCNN hasn't been implemented."
assert
self
.
with_bbox
,
"Bbox head must be implemented."
x
=
self
.
extract_feat
(
img
)
...
...
@@ -152,12 +153,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
bbox_results
=
bbox2result
(
det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
if
self
.
with_mask
:
if
not
self
.
with_mask
:
return
bbox_results
else
:
segm_results
=
self
.
simple_test_mask
(
x
,
img_meta
,
det_bboxes
,
det_labels
,
rescale
=
rescale
)
return
bbox_results
,
segm_results
else
:
return
bbox_results
def
aug_test
(
self
,
imgs
,
img_metas
,
rescale
=
False
):
"""Test with augmentations.
...
...
@@ -165,7 +166,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
# recompute
self.extract_feats(imgs) because of 'yield' and
memory
# recompute
feats to save
memory
proposal_list
=
self
.
aug_test_rpn
(
self
.
extract_feats
(
imgs
),
img_metas
,
self
.
test_cfg
.
rpn
)
det_bboxes
,
det_labels
=
self
.
aug_test_bboxes
(
...
...
@@ -183,10 +184,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# det_bboxes always keep the original scale
if
self
.
with_mask
:
segm_results
=
self
.
aug_test_mask
(
self
.
extract_feats
(
imgs
),
img_metas
,
det_bboxes
,
det_labels
)
self
.
extract_feats
(
imgs
),
img_metas
,
det_bboxes
,
det_labels
)
return
bbox_results
,
segm_results
else
:
return
bbox_results
tools/configs/r50_fpn_rpn_1x.py
View file @
d0fb2a8d
...
...
@@ -114,4 +114,4 @@ log_level = 'INFO'
work_dir
=
'./work_dirs/fpn_rpn_r50_1x'
load_from
=
None
resume_from
=
None
workflow
=
[(
'train'
,
1
),
(
'val'
,
1
)]
workflow
=
[(
'train'
,
1
)]
tools/dist_train.sh
View file @
d0fb2a8d
...
...
@@ -2,4 +2,4 @@
PYTHON
=
${
PYTHON
:-
"python"
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
$
3
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
$
{
@
:3
}
tools/train.py
View file @
d0fb2a8d
from
__future__
import
division
import
argparse
import
logging
from
collections
import
OrderedDict
import
torch
...
...
@@ -45,9 +46,17 @@ def batch_processor(model, data, train_mode):
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work_dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--validate'
,
action
=
'store_true'
,
...
...
@@ -69,16 +78,22 @@ def main():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
.
update
(
gpus
=
args
.
gpus
)
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
logger
=
get_logger
(
cfg
.
log_level
)
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
dist
=
False
print
(
'Disabled distributed training.'
)
logger
.
info
(
'Disabled distributed training.'
)
else
:
dist
=
True
print
(
'Enabled distributed training.'
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Enabled distributed training.'
)
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment