Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
d0fb2a8d
Commit
d0fb2a8d
authored
Sep 26, 2018
by
Kai Chen
Browse files
suppress logging for processes whose rank > 0
parent
5421859a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
46 additions
and
29 deletions
+46
-29
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+4
-1
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+6
-4
mmdet/models/detectors/rpn.py
mmdet/models/detectors/rpn.py
+1
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+15
-17
tools/configs/r50_fpn_rpn_1x.py
tools/configs/r50_fpn_rpn_1x.py
+1
-1
tools/dist_train.sh
tools/dist_train.sh
+1
-1
tools/train.py
tools/train.py
+18
-3
No files found.
mmdet/models/backbones/resnet.py
View file @
d0fb2a8d
import
logging
import
math
import
math
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
import
torch.utils.checkpoint
as
cp
from
mmcv.torchpack
import
load_checkpoint
from
mmcv.torchpack
import
load_checkpoint
...
@@ -241,7 +243,8 @@ class ResNet(nn.Module):
...
@@ -241,7 +243,8 @@ class ResNet(nn.Module):
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
if
isinstance
(
pretrained
,
str
):
load_checkpoint
(
self
,
pretrained
,
strict
=
False
)
logger
=
logging
.
getLogger
()
load_checkpoint
(
self
,
pretrained
,
strict
=
False
,
logger
=
logger
)
elif
pretrained
is
None
:
elif
pretrained
is
None
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
...
...
mmdet/models/detectors/base.py
View file @
d0fb2a8d
import
logging
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch
import
torch
...
@@ -12,10 +13,6 @@ class BaseDetector(nn.Module):
...
@@ -12,10 +13,6 @@ class BaseDetector(nn.Module):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
BaseDetector
,
self
).
__init__
()
super
(
BaseDetector
,
self
).
__init__
()
@
abstractmethod
def
init_weights
(
self
):
pass
@
abstractmethod
@
abstractmethod
def
extract_feat
(
self
,
imgs
):
def
extract_feat
(
self
,
imgs
):
pass
pass
...
@@ -39,6 +36,11 @@ class BaseDetector(nn.Module):
...
@@ -39,6 +36,11 @@ class BaseDetector(nn.Module):
def
aug_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
def
aug_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
pass
pass
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
logger
=
logging
.
getLogger
()
logger
.
info
(
'load model from: {}'
.
format
(
pretrained
))
def
forward_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
def
forward_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
for
var
,
name
in
[(
imgs
,
'imgs'
),
(
img_metas
,
'img_metas'
)]:
for
var
,
name
in
[(
imgs
,
'imgs'
),
(
img_metas
,
'img_metas'
)]:
if
not
isinstance
(
var
,
list
):
if
not
isinstance
(
var
,
list
):
...
...
mmdet/models/detectors/rpn.py
View file @
d0fb2a8d
...
@@ -24,8 +24,7 @@ class RPN(BaseDetector, RPNTestMixin):
...
@@ -24,8 +24,7 @@ class RPN(BaseDetector, RPNTestMixin):
self
.
init_weights
(
pretrained
=
pretrained
)
self
.
init_weights
(
pretrained
=
pretrained
)
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
super
(
RPN
,
self
).
init_weights
(
pretrained
)
print
(
'load model from: {}'
.
format
(
pretrained
))
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
if
self
.
neck
is
not
None
:
if
self
.
neck
is
not
None
:
self
.
neck
.
init_weights
()
self
.
neck
.
init_weights
()
...
...
mmdet/models/detectors/two_stage.py
View file @
d0fb2a8d
...
@@ -24,10 +24,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -24,10 +24,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
super
(
TwoStageDetector
,
self
).
__init__
()
super
(
TwoStageDetector
,
self
).
__init__
()
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
with_neck
=
True
if
neck
is
not
None
else
False
if
neck
is
not
None
:
assert
self
.
with_neck
,
"TwoStageDetector must be implemented with FPN now."
self
.
with_neck
=
True
if
self
.
with_neck
:
self
.
neck
=
builder
.
build_neck
(
neck
)
self
.
neck
=
builder
.
build_neck
(
neck
)
else
:
raise
NotImplementedError
self
.
with_rpn
=
True
if
rpn_head
is
not
None
else
False
self
.
with_rpn
=
True
if
rpn_head
is
not
None
else
False
if
self
.
with_rpn
:
if
self
.
with_rpn
:
...
@@ -51,8 +52,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -51,8 +52,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self
.
init_weights
(
pretrained
=
pretrained
)
self
.
init_weights
(
pretrained
=
pretrained
)
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
pretrained
is
not
None
:
super
(
TwoStageDetector
,
self
).
init_weights
(
pretrained
)
print
(
'load model from: {}'
.
format
(
pretrained
))
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
self
.
backbone
.
init_weights
(
pretrained
=
pretrained
)
if
self
.
with_neck
:
if
self
.
with_neck
:
if
isinstance
(
self
.
neck
,
nn
.
Sequential
):
if
isinstance
(
self
.
neck
,
nn
.
Sequential
):
...
@@ -104,9 +104,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -104,9 +104,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
pos_gt_labels
)
=
multi_apply
(
pos_gt_labels
)
=
multi_apply
(
self
.
bbox_roi_extractor
.
sample_proposals
,
proposal_list
,
self
.
bbox_roi_extractor
.
sample_proposals
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
rcnn_train_cfg_list
)
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
rcnn_train_cfg_list
)
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
\
(
labels
,
label_weights
,
bbox_targets
,
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
bbox_weights
)
=
self
.
bbox_head
.
get_bbox_target
(
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
rois
=
bbox2roi
([
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
...
@@ -139,7 +140,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -139,7 +140,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
def
simple_test
(
self
,
img
,
img_meta
,
proposals
=
None
,
rescale
=
False
):
def
simple_test
(
self
,
img
,
img_meta
,
proposals
=
None
,
rescale
=
False
):
"""Test without augmentation."""
"""Test without augmentation."""
assert
proposals
==
None
,
"Fast RCNN hasn't been implemented."
assert
proposals
is
None
,
"Fast RCNN hasn't been implemented."
assert
self
.
with_bbox
,
"Bbox head must be implemented."
assert
self
.
with_bbox
,
"Bbox head must be implemented."
x
=
self
.
extract_feat
(
img
)
x
=
self
.
extract_feat
(
img
)
...
@@ -152,12 +153,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -152,12 +153,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
bbox_results
=
bbox2result
(
det_bboxes
,
det_labels
,
bbox_results
=
bbox2result
(
det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
self
.
bbox_head
.
num_classes
)
if
self
.
with_mask
:
if
not
self
.
with_mask
:
return
bbox_results
else
:
segm_results
=
self
.
simple_test_mask
(
segm_results
=
self
.
simple_test_mask
(
x
,
img_meta
,
det_bboxes
,
det_labels
,
rescale
=
rescale
)
x
,
img_meta
,
det_bboxes
,
det_labels
,
rescale
=
rescale
)
return
bbox_results
,
segm_results
return
bbox_results
,
segm_results
else
:
return
bbox_results
def
aug_test
(
self
,
imgs
,
img_metas
,
rescale
=
False
):
def
aug_test
(
self
,
imgs
,
img_metas
,
rescale
=
False
):
"""Test with augmentations.
"""Test with augmentations.
...
@@ -165,7 +166,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -165,7 +166,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
If rescale is False, then returned bboxes and masks will fit the scale
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
of imgs[0].
"""
"""
# recompute
self.extract_feats(imgs) because of 'yield' and
memory
# recompute
feats to save
memory
proposal_list
=
self
.
aug_test_rpn
(
proposal_list
=
self
.
aug_test_rpn
(
self
.
extract_feats
(
imgs
),
img_metas
,
self
.
test_cfg
.
rpn
)
self
.
extract_feats
(
imgs
),
img_metas
,
self
.
test_cfg
.
rpn
)
det_bboxes
,
det_labels
=
self
.
aug_test_bboxes
(
det_bboxes
,
det_labels
=
self
.
aug_test_bboxes
(
...
@@ -183,10 +184,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -183,10 +184,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# det_bboxes always keep the original scale
# det_bboxes always keep the original scale
if
self
.
with_mask
:
if
self
.
with_mask
:
segm_results
=
self
.
aug_test_mask
(
segm_results
=
self
.
aug_test_mask
(
self
.
extract_feats
(
imgs
),
self
.
extract_feats
(
imgs
),
img_metas
,
det_bboxes
,
det_labels
)
img_metas
,
det_bboxes
,
det_labels
)
return
bbox_results
,
segm_results
return
bbox_results
,
segm_results
else
:
else
:
return
bbox_results
return
bbox_results
tools/configs/r50_fpn_rpn_1x.py
View file @
d0fb2a8d
...
@@ -114,4 +114,4 @@ log_level = 'INFO'
...
@@ -114,4 +114,4 @@ log_level = 'INFO'
work_dir
=
'./work_dirs/fpn_rpn_r50_1x'
work_dir
=
'./work_dirs/fpn_rpn_r50_1x'
load_from
=
None
load_from
=
None
resume_from
=
None
resume_from
=
None
workflow
=
[(
'train'
,
1
),
(
'val'
,
1
)]
workflow
=
[(
'train'
,
1
)]
tools/dist_train.sh
View file @
d0fb2a8d
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
PYTHON
=
${
PYTHON
:-
"python"
}
PYTHON
=
${
PYTHON
:-
"python"
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
$
3
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
$
{
@
:3
}
tools/train.py
View file @
d0fb2a8d
from
__future__
import
division
from
__future__
import
division
import
argparse
import
argparse
import
logging
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
torch
import
torch
...
@@ -45,9 +46,17 @@ def batch_processor(model, data, train_mode):
...
@@ -45,9 +46,17 @@ def batch_processor(model, data, train_mode):
return
outputs
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work_dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--validate'
,
'--validate'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -69,16 +78,22 @@ def main():
...
@@ -69,16 +78,22 @@ def main():
args
=
parse_args
()
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
.
update
(
gpus
=
args
.
gpus
)
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
logger
=
get_logger
(
cfg
.
log_level
)
# init distributed environment if necessary
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
if
args
.
launcher
==
'none'
:
dist
=
False
dist
=
False
print
(
'Disabled distributed training.'
)
logger
.
info
(
'Disabled distributed training.'
)
else
:
else
:
dist
=
True
dist
=
True
print
(
'Enabled distributed training.'
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Enabled distributed training.'
)
# prepare data loaders
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment