Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
afe5ce0a
Unverified
Commit
afe5ce0a
authored
Sep 25, 2018
by
Kai Chen
Committed by
GitHub
Sep 25, 2018
Browse files
Merge pull request #1 from OceanPang/dev
faster-rcnn & mask-rcnn train and test support
parents
0401cccd
782ba019
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
14 deletions
+63
-14
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+7
-5
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+31
-0
tools/configs/r50_fpn_frcnn_1x.py
tools/configs/r50_fpn_frcnn_1x.py
+12
-4
tools/configs/r50_fpn_maskrcnn_1x.py
tools/configs/r50_fpn_maskrcnn_1x.py
+13
-5
No files found.
mmdet/models/mask_heads/fcn_mask_head.py
View file @
afe5ce0a
...
@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
...
@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return
mask_targets
return
mask_targets
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
loss
=
dict
()
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
return
loss_mask
loss
[
'loss_mask'
]
=
loss_mask
return
loss
def
get_seg_masks
(
self
,
mask_pred
,
det_bboxes
,
det_labels
,
rcnn_test_cfg
,
def
get_seg_masks
(
self
,
mask_pred
,
det_bboxes
,
det_labels
,
rcnn_test_cfg
,
ori_s
cal
e
):
ori_s
hap
e
):
"""Get segmentation masks from mask_pred and bboxes
"""Get segmentation masks from mask_pred and bboxes
Args:
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
...
@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
...
@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, )
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to
original image size
ori_shape:
original image size
Returns:
Returns:
list[list]: encoded masks
list[list]: encoded masks
"""
"""
...
@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
...
@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms
=
[[]
for
_
in
range
(
self
.
num_classes
-
1
)]
cls_segms
=
[[]
for
_
in
range
(
self
.
num_classes
-
1
)]
bboxes
=
det_bboxes
.
cpu
().
numpy
()[:,
:
4
]
bboxes
=
det_bboxes
.
cpu
().
numpy
()[:,
:
4
]
labels
=
det_labels
.
cpu
().
numpy
()
+
1
labels
=
det_labels
.
cpu
().
numpy
()
+
1
img_h
=
ori_s
cal
e
[
0
]
img_h
=
ori_s
hap
e
[
0
]
img_w
=
ori_s
cal
e
[
1
]
img_w
=
ori_s
hap
e
[
1
]
for
i
in
range
(
bboxes
.
shape
[
0
]):
for
i
in
range
(
bboxes
.
shape
[
0
]):
bbox
=
bboxes
[
i
,
:].
astype
(
int
)
bbox
=
bboxes
[
i
,
:].
astype
(
int
)
...
...
mmdet/models/roi_extractors/single_level.py
View file @
afe5ce0a
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmdet
import
ops
from
mmdet
import
ops
from
mmdet.core
import
bbox_assign
,
bbox_sampling
class
SingleLevelRoI
(
nn
.
Module
):
class
SingleLevelRoI
(
nn
.
Module
):
...
@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
...
@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
return
target_lvls
return
target_lvls
def
sample_proposals
(
self
,
proposals
,
gt_bboxes
,
gt_crowds
,
gt_labels
,
cfg
):
proposals
=
proposals
[:,
:
4
]
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
bbox_assign
(
proposals
,
gt_bboxes
,
gt_crowds
,
gt_labels
,
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
pos_iou_thr
,
cfg
.
crowd_thr
)
if
cfg
.
add_gt_as_proposals
:
proposals
=
torch
.
cat
([
gt_bboxes
,
proposals
],
dim
=
0
)
gt_assign_self
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
proposals
.
device
)
assigned_gt_inds
=
torch
.
cat
([
gt_assign_self
,
assigned_gt_inds
])
assigned_labels
=
torch
.
cat
([
gt_labels
,
assigned_labels
])
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
roi_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
pos_proposals
=
proposals
[
pos_inds
]
neg_proposals
=
proposals
[
neg_inds
]
pos_assigned_gt_inds
=
assigned_gt_inds
[
pos_inds
]
-
1
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
return
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
def
forward
(
self
,
feats
,
rois
):
def
forward
(
self
,
feats
,
rois
):
"""Extract roi features with the roi layer. If multiple feature levels
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
are used, then rois are mapped to corresponding levels according to
...
...
tools/configs/r50_fpn_frcnn_1x.py
View file @
afe5ce0a
...
@@ -90,7 +90,11 @@ data = dict(
...
@@ -90,7 +90,11 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
size_divisor
=
32
,
flip_ratio
=
0.5
),
flip_ratio
=
0.5
,
with_mask
=
False
,
with_crowd
=
True
,
with_label
=
True
,
test_mode
=
False
),
test
=
dict
(
test
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
...
@@ -98,7 +102,10 @@ data = dict(
...
@@ -98,7 +102,10 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_scale
=
(
1333
,
800
),
flip_ratio
=
0
,
flip_ratio
=
0
,
img_norm_cfg
=
img_norm_cfg
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
))
size_divisor
=
32
,
with_mask
=
False
,
with_label
=
False
,
test_mode
=
True
))
# optimizer
# optimizer
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
...
@@ -112,7 +119,7 @@ lr_config = dict(
...
@@ -112,7 +119,7 @@ lr_config = dict(
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
5
0
,
interval
=
2
0
,
hooks
=
[
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
dict
(
type
=
'TextLoggerHook'
),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
...
@@ -120,7 +127,8 @@ log_config = dict(
...
@@ -120,7 +127,8 @@ log_config = dict(
# yapf:enable
# yapf:enable
# runtime settings
# runtime settings
total_epochs
=
12
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
device_ids
=
range
(
8
)
dist_params
=
dict
(
backend
=
'nccl'
,
port
=
'29500'
)
log_level
=
'INFO'
log_level
=
'INFO'
work_dir
=
'./work_dirs/fpn_faster_rcnn_r50_1x'
work_dir
=
'./work_dirs/fpn_faster_rcnn_r50_1x'
load_from
=
None
load_from
=
None
...
...
tools/configs/r50_fpn_maskrcnn_1x.py
View file @
afe5ce0a
...
@@ -103,7 +103,11 @@ data = dict(
...
@@ -103,7 +103,11 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
size_divisor
=
32
,
flip_ratio
=
0.5
),
flip_ratio
=
0.5
,
with_mask
=
True
,
with_crowd
=
True
,
with_label
=
True
,
test_mode
=
False
),
test
=
dict
(
test
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
...
@@ -111,7 +115,10 @@ data = dict(
...
@@ -111,7 +115,10 @@ data = dict(
img_scale
=
(
1333
,
800
),
img_scale
=
(
1333
,
800
),
flip_ratio
=
0
,
flip_ratio
=
0
,
img_norm_cfg
=
img_norm_cfg
,
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
))
size_divisor
=
32
,
with_mask
=
False
,
with_label
=
False
,
test_mode
=
True
))
# optimizer
# optimizer
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
...
@@ -120,12 +127,12 @@ lr_config = dict(
...
@@ -120,12 +127,12 @@ lr_config = dict(
policy
=
'step'
,
policy
=
'step'
,
warmup
=
'linear'
,
warmup
=
'linear'
,
warmup_iters
=
500
,
warmup_iters
=
500
,
warmup_ratio
=
0.33
3
,
warmup_ratio
=
1.0
/
3
,
step
=
[
8
,
11
])
step
=
[
8
,
11
])
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
5
0
,
interval
=
2
0
,
hooks
=
[
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
dict
(
type
=
'TextLoggerHook'
),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
...
@@ -133,7 +140,8 @@ log_config = dict(
...
@@ -133,7 +140,8 @@ log_config = dict(
# yapf:enable
# yapf:enable
# runtime settings
# runtime settings
total_epochs
=
12
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
device_ids
=
range
(
8
)
dist_params
=
dict
(
backend
=
'nccl'
,
port
=
'29500'
)
log_level
=
'INFO'
log_level
=
'INFO'
work_dir
=
'./work_dirs/fpn_mask_rcnn_r50_1x'
work_dir
=
'./work_dirs/fpn_mask_rcnn_r50_1x'
load_from
=
None
load_from
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment