Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
afe5ce0a
Unverified
Commit
afe5ce0a
authored
Sep 25, 2018
by
Kai Chen
Committed by
GitHub
Sep 25, 2018
Browse files
Merge pull request #1 from OceanPang/dev
faster-rcnn & mask-rcnn train and test support
parents
0401cccd
782ba019
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
14 deletions
+63
-14
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+7
-5
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+31
-0
tools/configs/r50_fpn_frcnn_1x.py
tools/configs/r50_fpn_frcnn_1x.py
+12
-4
tools/configs/r50_fpn_maskrcnn_1x.py
tools/configs/r50_fpn_maskrcnn_1x.py
+13
-5
No files found.
mmdet/models/mask_heads/fcn_mask_head.py
View file @
afe5ce0a
...
...
@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return
mask_targets
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
loss
=
dict
()
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
return
loss_mask
loss
[
'loss_mask'
]
=
loss_mask
return
loss
def
get_seg_masks
(
self
,
mask_pred
,
det_bboxes
,
det_labels
,
rcnn_test_cfg
,
ori_s
cal
e
):
ori_s
hap
e
):
"""Get segmentation masks from mask_pred and bboxes
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
...
...
@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to
original image size
ori_shape:
original image size
Returns:
list[list]: encoded masks
"""
...
...
@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms
=
[[]
for
_
in
range
(
self
.
num_classes
-
1
)]
bboxes
=
det_bboxes
.
cpu
().
numpy
()[:,
:
4
]
labels
=
det_labels
.
cpu
().
numpy
()
+
1
img_h
=
ori_s
cal
e
[
0
]
img_w
=
ori_s
cal
e
[
1
]
img_h
=
ori_s
hap
e
[
0
]
img_w
=
ori_s
hap
e
[
1
]
for
i
in
range
(
bboxes
.
shape
[
0
]):
bbox
=
bboxes
[
i
,
:].
astype
(
int
)
...
...
mmdet/models/roi_extractors/single_level.py
View file @
afe5ce0a
...
...
@@ -4,6 +4,7 @@ import torch
import
torch.nn
as
nn
from
mmdet
import
ops
from
mmdet.core
import
bbox_assign
,
bbox_sampling
class
SingleLevelRoI
(
nn
.
Module
):
...
...
@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
return
target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_crowds, gt_labels, cfg):
    """Assign proposals to ground-truth boxes and sample pos/neg RoIs.

    Only the first four columns of ``proposals`` are used. The assignment
    is delegated to ``bbox_assign`` and the sampling to ``bbox_sampling``;
    when ``cfg.add_gt_as_proposals`` is truthy the ground-truth boxes are
    prepended to the proposal set before sampling.

    Args:
        proposals: 2-D tensor of proposals; columns past the fourth are
            dropped.
        gt_bboxes: ground-truth boxes, indexed by row.
        gt_crowds: forwarded unchanged to ``bbox_assign``.
        gt_labels: labels of the ground-truth boxes.
        cfg: config object providing ``pos_iou_thr``, ``neg_iou_thr``,
            ``crowd_thr``, ``add_gt_as_proposals``, ``roi_batch_size``,
            ``pos_fraction``, ``neg_pos_ub``, ``pos_balance_sampling`` and
            ``neg_balance_thr``.

    Returns:
        tuple: ``(pos_proposals, neg_proposals, pos_assigned_gt_inds,
        pos_gt_bboxes, pos_gt_labels)``.
    """
    boxes = proposals[:, :4]

    # NOTE(review): cfg.pos_iou_thr is passed for both the 5th and the 7th
    # positional argument -- confirm against bbox_assign's signature that
    # this is intended.
    assignment = bbox_assign(
        boxes, gt_bboxes, gt_crowds, gt_labels,
        cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
    gt_inds, labels, _argmax_overlaps, overlaps = assignment

    if cfg.add_gt_as_proposals:
        boxes = torch.cat([gt_bboxes, boxes], dim=0)
        # Each prepended ground-truth box is assigned to itself, using
        # indices 1..len(gt_labels).
        self_inds = torch.arange(
            1,
            len(gt_labels) + 1,
            dtype=torch.long,
            device=boxes.device)
        gt_inds = torch.cat([self_inds, gt_inds])
        labels = torch.cat([gt_labels, labels])
        # NOTE(review): `overlaps` is not extended with entries for the
        # prepended boxes -- verify bbox_sampling tolerates the mismatch.

    pos_inds, neg_inds = bbox_sampling(
        gt_inds,
        cfg.roi_batch_size,
        cfg.pos_fraction,
        cfg.neg_pos_ub,
        cfg.pos_balance_sampling,
        overlaps,
        cfg.neg_balance_thr)

    sampled_pos = boxes[pos_inds]
    sampled_neg = boxes[neg_inds]
    # Subtract one so the assigned indices can be used to index gt_bboxes.
    matched_gt = gt_inds[pos_inds] - 1
    matched_boxes = gt_bboxes[matched_gt, :]
    matched_labels = labels[pos_inds]
    return (sampled_pos, sampled_neg, matched_gt, matched_boxes,
            matched_labels)
def
forward
(
self
,
feats
,
rois
):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
...
...
tools/configs/r50_fpn_frcnn_1x.py
View file @
afe5ce0a
...
...
@@ -90,7 +90,11 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
flip_ratio
=
0.5
),
flip_ratio
=
0.5
,
with_mask
=
False
,
with_crowd
=
True
,
with_label
=
True
,
test_mode
=
False
),
test
=
dict
(
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
...
...
@@ -98,7 +102,10 @@ data = dict(
img_scale
=
(
1333
,
800
),
flip_ratio
=
0
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
))
size_divisor
=
32
,
with_mask
=
False
,
with_label
=
False
,
test_mode
=
True
))
# optimizer
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
...
...
@@ -112,7 +119,7 @@ lr_config = dict(
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
log_config
=
dict
(
interval
=
5
0
,
interval
=
2
0
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
...
...
@@ -120,7 +127,8 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
device_ids
=
range
(
8
)
dist_params
=
dict
(
backend
=
'nccl'
,
port
=
'29500'
)
log_level
=
'INFO'
work_dir
=
'./work_dirs/fpn_faster_rcnn_r50_1x'
load_from
=
None
...
...
tools/configs/r50_fpn_maskrcnn_1x.py
View file @
afe5ce0a
...
...
@@ -103,7 +103,11 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
flip_ratio
=
0.5
),
flip_ratio
=
0.5
,
with_mask
=
True
,
with_crowd
=
True
,
with_label
=
True
,
test_mode
=
False
),
test
=
dict
(
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
...
...
@@ -111,7 +115,10 @@ data = dict(
img_scale
=
(
1333
,
800
),
flip_ratio
=
0
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
))
size_divisor
=
32
,
with_mask
=
False
,
with_label
=
False
,
test_mode
=
True
))
# optimizer
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
...
...
@@ -120,12 +127,12 @@ lr_config = dict(
policy
=
'step'
,
warmup
=
'linear'
,
warmup_iters
=
500
,
warmup_ratio
=
0.33
3
,
warmup_ratio
=
1.0
/
3
,
step
=
[
8
,
11
])
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
log_config
=
dict
(
interval
=
5
0
,
interval
=
2
0
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
...
...
@@ -133,7 +140,8 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
device_ids
=
range
(
8
)
dist_params
=
dict
(
backend
=
'nccl'
,
port
=
'29500'
)
log_level
=
'INFO'
work_dir
=
'./work_dirs/fpn_mask_rcnn_r50_1x'
load_from
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment