Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
98b20b9b
Commit
98b20b9b
authored
Oct 05, 2018
by
Kai Chen
Browse files
Merge branch 'dev' into mask-debug
parents
7e3ed283
f0cb1d12
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
146 additions
and
136 deletions
+146
-136
mmdet/core/bbox_ops/__init__.py
mmdet/core/bbox_ops/__init__.py
+5
-6
mmdet/core/bbox_ops/bbox_target.py
mmdet/core/bbox_ops/bbox_target.py
+21
-32
mmdet/core/bbox_ops/sampling.py
mmdet/core/bbox_ops/sampling.py
+29
-23
mmdet/core/bbox_ops/transforms.py
mmdet/core/bbox_ops/transforms.py
+21
-12
mmdet/core/eval/eval_hooks.py
mmdet/core/eval/eval_hooks.py
+0
-1
mmdet/core/eval/mean_ap.py
mmdet/core/eval/mean_ap.py
+27
-25
mmdet/core/losses/losses.py
mmdet/core/losses/losses.py
+2
-2
mmdet/core/rpn_ops/anchor_target.py
mmdet/core/rpn_ops/anchor_target.py
+3
-3
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+18
-12
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+3
-4
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+11
-9
mmdet/models/rpn_heads/rpn_head.py
mmdet/models/rpn_heads/rpn_head.py
+3
-4
tools/configs/r50_fpn_frcnn_1x.py
tools/configs/r50_fpn_frcnn_1x.py
+1
-1
tools/configs/r50_fpn_maskrcnn_1x.py
tools/configs/r50_fpn_maskrcnn_1x.py
+1
-1
tools/configs/r50_fpn_rpn_1x.py
tools/configs/r50_fpn_rpn_1x.py
+1
-1
No files found.
mmdet/core/bbox_ops/__init__.py
View file @
98b20b9b
from
.geometry
import
bbox_overlaps
from
.geometry
import
bbox_overlaps
from
.sampling
import
(
random_choice
,
bbox_assign
,
bbox_assign_
via
_overlaps
,
from
.sampling
import
(
random_choice
,
bbox_assign
,
bbox_assign_
wrt
_overlaps
,
bbox_sampling
,
sample_positives
,
sample_negatives
)
bbox_sampling
,
sample_positives
,
sample_negatives
)
from
.transforms
import
(
bbox_transform
,
bbox_transform_inv
,
bbox_flip
,
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
bbox_mapping
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
bbox2result
)
from
.bbox_target
import
bbox_target
from
.bbox_target
import
bbox_target
__all__
=
[
__all__
=
[
'bbox_overlaps'
,
'random_choice'
,
'bbox_assign'
,
'bbox_overlaps'
,
'random_choice'
,
'bbox_assign'
,
'bbox_assign_
via
_overlaps'
,
'bbox_sampling'
,
'sample_positives'
,
'bbox_assign_
wrt
_overlaps'
,
'bbox_sampling'
,
'sample_positives'
,
'sample_negatives'
,
'bbox
_transform'
,
'bbox_transform_inv
'
,
'bbox_flip'
,
'sample_negatives'
,
'bbox
2delta'
,
'delta2bbox
'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
'bbox_target'
]
]
mmdet/core/bbox_ops/bbox_target.py
View file @
98b20b9b
import
mmcv
import
torch
import
torch
from
.
geometry
import
bbox
_overlaps
from
.
transforms
import
bbox
2delta
from
.
transforms
import
bbox_transform
,
bbox_transform_inv
from
.
.utils
import
multi_apply
def
bbox_target
(
pos_proposals_list
,
def
bbox_target
(
pos_proposals_list
,
...
@@ -13,33 +12,23 @@ def bbox_target(pos_proposals_list,
...
@@ -13,33 +12,23 @@ def bbox_target(pos_proposals_list,
reg_num_classes
=
1
,
reg_num_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
return_list
=
False
):
concat
=
True
):
img_per_gpu
=
len
(
pos_proposals_list
)
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
multi_apply
(
all_labels
=
[]
proposal_target_single
,
all_label_weights
=
[]
pos_proposals_list
,
all_bbox_targets
=
[]
neg_proposals_list
,
all_bbox_weights
=
[]
pos_gt_bboxes_list
,
for
img_id
in
range
(
img_per_gpu
):
pos_gt_labels_list
,
pos_proposals
=
pos_proposals_list
[
img_id
]
cfg
=
cfg
,
neg_proposals
=
neg_proposals_list
[
img_id
]
reg_num_classes
=
reg_num_classes
,
pos_gt_bboxes
=
pos_gt_bboxes_list
[
img_id
]
target_means
=
target_means
,
pos_gt_labels
=
pos_gt_labels_list
[
img_id
]
target_stds
=
target_stds
)
debug_img
=
debug_imgs
[
img_id
]
if
cfg
.
debug
else
None
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
proposal_target_single
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
reg_num_classes
,
cfg
,
target_means
,
target_stds
)
all_labels
.
append
(
labels
)
all_label_weights
.
append
(
label_weights
)
all_bbox_targets
.
append
(
bbox_targets
)
all_bbox_weights
.
append
(
bbox_weights
)
if
return_list
:
if
concat
:
return
all_labels
,
all_label_weights
,
all_bbox_targets
,
all_bbox_weights
labels
=
torch
.
cat
(
labels
,
0
)
label_weights
=
torch
.
cat
(
label_weights
,
0
)
labels
=
torch
.
cat
(
all_labels
,
0
)
bbox_targets
=
torch
.
cat
(
bbox_targets
,
0
)
label_weights
=
torch
.
cat
(
all_label_weights
,
0
)
bbox_weights
=
torch
.
cat
(
bbox_weights
,
0
)
bbox_targets
=
torch
.
cat
(
all_bbox_targets
,
0
)
bbox_weights
=
torch
.
cat
(
all_bbox_weights
,
0
)
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
...
@@ -47,8 +36,8 @@ def proposal_target_single(pos_proposals,
...
@@ -47,8 +36,8 @@ def proposal_target_single(pos_proposals,
neg_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_bboxes
,
pos_gt_labels
,
pos_gt_labels
,
reg_num_classes
,
cfg
,
cfg
,
reg_num_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
num_pos
=
pos_proposals
.
size
(
0
)
num_pos
=
pos_proposals
.
size
(
0
)
...
@@ -62,8 +51,8 @@ def proposal_target_single(pos_proposals,
...
@@ -62,8 +51,8 @@ def proposal_target_single(pos_proposals,
labels
[:
num_pos
]
=
pos_gt_labels
labels
[:
num_pos
]
=
pos_gt_labels
pos_weight
=
1.0
if
cfg
.
pos_weight
<=
0
else
cfg
.
pos_weight
pos_weight
=
1.0
if
cfg
.
pos_weight
<=
0
else
cfg
.
pos_weight
label_weights
[:
num_pos
]
=
pos_weight
label_weights
[:
num_pos
]
=
pos_weight
pos_bbox_targets
=
bbox
_transform
(
pos_proposals
,
pos_gt_bboxes
,
pos_bbox_targets
=
bbox
2delta
(
pos_proposals
,
pos_gt_bboxes
,
target_means
,
target_stds
)
target_means
,
target_stds
)
bbox_targets
[:
num_pos
,
:]
=
pos_bbox_targets
bbox_targets
[:
num_pos
,
:]
=
pos_bbox_targets
bbox_weights
[:
num_pos
,
:]
=
1
bbox_weights
[:
num_pos
,
:]
=
1
if
num_neg
>
0
:
if
num_neg
>
0
:
...
...
mmdet/core/bbox_ops/sampling.py
View file @
98b20b9b
...
@@ -20,30 +20,36 @@ def random_choice(gallery, num):
...
@@ -20,30 +20,36 @@ def random_choice(gallery, num):
def
bbox_assign
(
proposals
,
def
bbox_assign
(
proposals
,
gt_bboxes
,
gt_bboxes
,
gt_
crowd_
bboxes
=
None
,
gt_bboxes
_ignore
=
None
,
gt_labels
=
None
,
gt_labels
=
None
,
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
.
0
,
min_pos_iou
=
.
0
,
crowd_thr
=-
1
):
crowd_thr
=-
1
):
"""Assign a corresponding gt bbox or background to each proposal/anchor
"""Assign a corresponding gt bbox or background to each proposal/anchor.
This function assign a gt bbox to every proposal, each proposals will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
Each proposals will be assigned with `-1`, `0`, or a positive integer.
negative sample, positive number is the index (1-based) of assigned gt.
If gt_crowd_bboxes is not None, proposals which have iof(intersection over foreground)
- -1: don't care
with crowd bboxes over crowd_thr will be ignored
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection
over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored.
Args:
Args:
proposals(Tensor): proposals or RPN anchors, shape (n, 4)
proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
gt_bboxes(Tensor): shape (k, 4)
gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
gt_crowd_bboxes(Tensor): shape(m, 4)
gt_bboxes_ignore (Tensor, optional): shape(m, 4).
gt_labels(Tensor, optional): shape (k, )
gt_labels (Tensor, optional): shape (k, ).
pos_iou_thr(float): iou threshold for positive bboxes
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr(float or tuple): iou threshold for negative bboxes
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox,
min_pos_iou (float): Minimum iou for a bbox to be considered as a
for RPN, it is usually set as 0, for Fast R-CNN,
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
it is usually set as pos_iou_thr
it is usually set as pos_iou_thr
crowd_thr: ignore proposals which have iof(intersection over foreground) with
crowd_thr (float): IoF threshold for ignoring bboxes. Negative value
crowd bboxes over crowd_thr
for not ignoring any bboxes.
Returns:
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
"""
...
@@ -54,20 +60,20 @@ def bbox_assign(proposals,
...
@@ -54,20 +60,20 @@ def bbox_assign(proposals,
raise
ValueError
(
'No gt bbox or proposals'
)
raise
ValueError
(
'No gt bbox or proposals'
)
# ignore proposals according to crowd bboxes
# ignore proposals according to crowd bboxes
if
(
crowd_thr
>
0
)
and
(
gt_
crowd_
bboxes
is
if
(
crowd_thr
>
0
)
and
(
gt_bboxes
_ignore
is
not
None
)
and
(
gt_
crowd_
bboxes
.
numel
()
>
0
):
not
None
)
and
(
gt_bboxes
_ignore
.
numel
()
>
0
):
crowd_overlaps
=
bbox_overlaps
(
proposals
,
gt_
crowd_
bboxes
,
mode
=
'iof'
)
crowd_overlaps
=
bbox_overlaps
(
proposals
,
gt_bboxes
_ignore
,
mode
=
'iof'
)
crowd_max_overlaps
,
_
=
crowd_overlaps
.
max
(
dim
=
1
)
crowd_max_overlaps
,
_
=
crowd_overlaps
.
max
(
dim
=
1
)
crowd_bboxes_inds
=
torch
.
nonzero
(
crowd_bboxes_inds
=
torch
.
nonzero
(
crowd_max_overlaps
>
crowd_thr
).
long
()
crowd_max_overlaps
>
crowd_thr
).
long
()
if
crowd_bboxes_inds
.
numel
()
>
0
:
if
crowd_bboxes_inds
.
numel
()
>
0
:
overlaps
[
crowd_bboxes_inds
,
:]
=
-
1
overlaps
[
crowd_bboxes_inds
,
:]
=
-
1
return
bbox_assign_
via
_overlaps
(
overlaps
,
gt_labels
,
pos_iou_thr
,
return
bbox_assign_
wrt
_overlaps
(
overlaps
,
gt_labels
,
pos_iou_thr
,
neg_iou_thr
,
min_pos_iou
)
neg_iou_thr
,
min_pos_iou
)
def
bbox_assign_
via
_overlaps
(
overlaps
,
def
bbox_assign_
wrt
_overlaps
(
overlaps
,
gt_labels
=
None
,
gt_labels
=
None
,
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
...
...
mmdet/core/bbox_ops/transforms.py
View file @
98b20b9b
...
@@ -3,7 +3,7 @@ import numpy as np
...
@@ -3,7 +3,7 @@ import numpy as np
import
torch
import
torch
def
bbox
_transform
(
proposals
,
gt
,
means
=
[
0
,
0
,
0
,
0
],
stds
=
[
1
,
1
,
1
,
1
]):
def
bbox
2delta
(
proposals
,
gt
,
means
=
[
0
,
0
,
0
,
0
],
stds
=
[
1
,
1
,
1
,
1
]):
assert
proposals
.
size
()
==
gt
.
size
()
assert
proposals
.
size
()
==
gt
.
size
()
proposals
=
proposals
.
float
()
proposals
=
proposals
.
float
()
...
@@ -31,12 +31,12 @@ def bbox_transform(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]):
...
@@ -31,12 +31,12 @@ def bbox_transform(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]):
return
deltas
return
deltas
def
bbox_transform_inv
(
rois
,
def
delta2bbox
(
rois
,
deltas
,
deltas
,
means
=
[
0
,
0
,
0
,
0
],
means
=
[
0
,
0
,
0
,
0
],
stds
=
[
1
,
1
,
1
,
1
],
stds
=
[
1
,
1
,
1
,
1
],
max_shape
=
None
,
max_shape
=
None
,
wh_ratio_clip
=
16
/
1000
):
wh_ratio_clip
=
16
/
1000
):
means
=
deltas
.
new_tensor
(
means
).
repeat
(
1
,
deltas
.
size
(
1
)
//
4
)
means
=
deltas
.
new_tensor
(
means
).
repeat
(
1
,
deltas
.
size
(
1
)
//
4
)
stds
=
deltas
.
new_tensor
(
stds
).
repeat
(
1
,
deltas
.
size
(
1
)
//
4
)
stds
=
deltas
.
new_tensor
(
stds
).
repeat
(
1
,
deltas
.
size
(
1
)
//
4
)
denorm_deltas
=
deltas
*
stds
+
means
denorm_deltas
=
deltas
*
stds
+
means
...
@@ -69,10 +69,14 @@ def bbox_transform_inv(rois,
...
@@ -69,10 +69,14 @@ def bbox_transform_inv(rois,
def
bbox_flip
(
bboxes
,
img_shape
):
def
bbox_flip
(
bboxes
,
img_shape
):
"""Flip bboxes horizontally
"""Flip bboxes horizontally.
Args:
Args:
bboxes(Tensor): shape (..., 4*k)
bboxes(Tensor or ndarray): Shape (..., 4*k)
img_shape(Tensor): image shape
img_shape(tuple): Image shape.
Returns:
Same type as `bboxes`: Flipped bboxes.
"""
"""
if
isinstance
(
bboxes
,
torch
.
Tensor
):
if
isinstance
(
bboxes
,
torch
.
Tensor
):
assert
bboxes
.
shape
[
-
1
]
%
4
==
0
assert
bboxes
.
shape
[
-
1
]
%
4
==
0
...
@@ -101,8 +105,11 @@ def bbox_mapping_back(bboxes, img_shape, scale_factor, flip):
...
@@ -101,8 +105,11 @@ def bbox_mapping_back(bboxes, img_shape, scale_factor, flip):
def
bbox2roi
(
bbox_list
):
def
bbox2roi
(
bbox_list
):
"""Convert a list of bboxes to roi format.
"""Convert a list of bboxes to roi format.
Args:
Args:
bbox_list (Tensor): a list of bboxes corresponding to a list of images
bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
of images.
Returns:
Returns:
Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
"""
"""
...
@@ -129,11 +136,13 @@ def roi2bbox(rois):
...
@@ -129,11 +136,13 @@ def roi2bbox(rois):
def
bbox2result
(
bboxes
,
labels
,
num_classes
):
def
bbox2result
(
bboxes
,
labels
,
num_classes
):
"""Convert detection results to a list of numpy arrays
"""Convert detection results to a list of numpy arrays.
Args:
Args:
bboxes (Tensor): shape (n, 5)
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
labels (Tensor): shape (n, )
num_classes (int): class number, including background class
num_classes (int): class number, including background class
Returns:
Returns:
list(ndarray): bbox results of each class
list(ndarray): bbox results of each class
"""
"""
...
...
mmdet/core/eval/eval_hooks.py
View file @
98b20b9b
...
@@ -11,7 +11,6 @@ from pycocotools.cocoeval import COCOeval
...
@@ -11,7 +11,6 @@ from pycocotools.cocoeval import COCOeval
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.recall
import
eval_recalls
from
..parallel
import
scatter
from
..parallel
import
scatter
from
mmdet
import
datasets
from
mmdet
import
datasets
from
mmdet.datasets.loader
import
collate
from
mmdet.datasets.loader
import
collate
...
...
mmdet/core/eval/mean_ap.py
View file @
98b20b9b
...
@@ -9,9 +9,9 @@ def average_precision(recalls, precisions, mode='area'):
...
@@ -9,9 +9,9 @@ def average_precision(recalls, precisions, mode='area'):
"""Calculate average precision (for single or multiple scales).
"""Calculate average precision (for single or multiple scales).
Args:
Args:
recalls(ndarray): shape (num_scales, num_dets) or (num_dets, )
recalls
(ndarray): shape (num_scales, num_dets) or (num_dets, )
precisions(ndarray): shape (num_scales, num_dets) or (num_dets, )
precisions
(ndarray): shape (num_scales, num_dets) or (num_dets, )
mode(str): 'area' or '11points', 'area' means calculating the area
mode
(str): 'area' or '11points', 'area' means calculating the area
under precision-recall curve, '11points' means calculating
under precision-recall curve, '11points' means calculating
the average precision of recalls at [0, 0.1, ..., 1]
the average precision of recalls at [0, 0.1, ..., 1]
...
@@ -60,11 +60,11 @@ def tpfp_imagenet(det_bboxes,
...
@@ -60,11 +60,11 @@ def tpfp_imagenet(det_bboxes,
"""Check if detected bboxes are true positive or false positive.
"""Check if detected bboxes are true positive or false positive.
Args:
Args:
det_bbox(ndarray): the detected bbox
det_bbox
(ndarray): the detected bbox
gt_bboxes(ndarray): ground truth bboxes of this image
gt_bboxes
(ndarray): ground truth bboxes of this image
gt_ignore(ndarray): indicate if gts are ignored for evaluation or not
gt_ignore
(ndarray): indicate if gts are ignored for evaluation or not
default_iou_thr(float): the iou thresholds for medium and large bboxes
default_iou_thr
(float): the iou thresholds for medium and large bboxes
area_ranges(list or None): gt bbox area ranges
area_ranges
(list or None): gt bbox area ranges
Returns:
Returns:
tuple: two arrays (tp, fp) whose elements are 0 and 1
tuple: two arrays (tp, fp) whose elements are 0 and 1
...
@@ -115,10 +115,10 @@ def tpfp_imagenet(det_bboxes,
...
@@ -115,10 +115,10 @@ def tpfp_imagenet(det_bboxes,
max_iou
=
ious
[
i
,
j
]
max_iou
=
ious
[
i
,
j
]
matched_gt
=
j
matched_gt
=
j
# there are 4 cases for a det bbox:
# there are 4 cases for a det bbox:
# 1. t
his det bbox
matches a gt, tp = 1, fp = 0
# 1.
i
t matches a gt, tp = 1, fp = 0
# 2. t
his det bbox
matches an ignored gt, tp = 0, fp = 0
# 2.
i
t matches an ignored gt, tp = 0, fp = 0
# 3. t
his det bbox
matches no gt and within area range, tp = 0, fp = 1
# 3.
i
t matches no gt and within area range, tp = 0, fp = 1
# 4. t
his det bbox
matches no gt but is beyond area range, tp = 0, fp = 0
# 4.
i
t matches no gt but is beyond area range, tp = 0, fp = 0
if
matched_gt
>=
0
:
if
matched_gt
>=
0
:
gt_covered
[
matched_gt
]
=
1
gt_covered
[
matched_gt
]
=
1
if
not
(
gt_ignore
[
matched_gt
]
or
gt_area_ignore
[
matched_gt
]):
if
not
(
gt_ignore
[
matched_gt
]
or
gt_area_ignore
[
matched_gt
]):
...
@@ -137,10 +137,10 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
...
@@ -137,10 +137,10 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
"""Check if detected bboxes are true positive or false positive.
"""Check if detected bboxes are true positive or false positive.
Args:
Args:
det_bbox(ndarray): the detected bbox
det_bbox
(ndarray): the detected bbox
gt_bboxes(ndarray): ground truth bboxes of this image
gt_bboxes
(ndarray): ground truth bboxes of this image
gt_ignore(ndarray): indicate if gts are ignored for evaluation or not
gt_ignore
(ndarray): indicate if gts are ignored for evaluation or not
iou_thr(float): the iou thresholds
iou_thr
(float): the iou thresholds
Returns:
Returns:
tuple: (tp, fp), two arrays whose elements are 0 and 1
tuple: (tp, fp), two arrays whose elements are 0 and 1
...
@@ -227,15 +227,16 @@ def eval_map(det_results,
...
@@ -227,15 +227,16 @@ def eval_map(det_results,
"""Evaluate mAP of a dataset.
"""Evaluate mAP of a dataset.
Args:
Args:
det_results(list): a list of list, [[cls1_det, cls2_det, ...], ...]
det_results (list): a list of list, [[cls1_det, cls2_det, ...], ...]
gt_bboxes(list): ground truth bboxes of each image, a list of K*4 array
gt_bboxes (list): ground truth bboxes of each image, a list of K*4
gt_labels(list): ground truth labels of each image, a list of K array
array.
gt_ignore(list): gt ignore indicators of each image, a list of K array
gt_labels (list): ground truth labels of each image, a list of K array
scale_ranges(list, optional): [(min1, max1), (min2, max2), ...]
gt_ignore (list): gt ignore indicators of each image, a list of K array
iou_thr(float): IoU threshold
scale_ranges (list, optional): [(min1, max1), (min2, max2), ...]
dataset(None or str): dataset name, there are minor differences in
iou_thr (float): IoU threshold
dataset (None or str): dataset name, there are minor differences in
metrics for different datsets, e.g. "voc07", "imagenet_det", etc.
metrics for different datsets, e.g. "voc07", "imagenet_det", etc.
print_summary(bool): whether to print the mAP summary
print_summary
(bool): whether to print the mAP summary
Returns:
Returns:
tuple: (mAP, [dict, dict, ...])
tuple: (mAP, [dict, dict, ...])
...
@@ -265,7 +266,8 @@ def eval_map(det_results,
...
@@ -265,7 +266,8 @@ def eval_map(det_results,
area_ranges
)
for
j
in
range
(
len
(
cls_dets
))
area_ranges
)
for
j
in
range
(
len
(
cls_dets
))
]
]
tp
,
fp
=
tuple
(
zip
(
*
tpfp
))
tp
,
fp
=
tuple
(
zip
(
*
tpfp
))
# calculate gt number of each scale, gts ignored or beyond scale are not counted
# calculate gt number of each scale, gts ignored or beyond scale
# are not counted
num_gts
=
np
.
zeros
(
num_scales
,
dtype
=
int
)
num_gts
=
np
.
zeros
(
num_scales
,
dtype
=
int
)
for
j
,
bbox
in
enumerate
(
cls_gts
):
for
j
,
bbox
in
enumerate
(
cls_gts
):
if
area_ranges
is
None
:
if
area_ranges
is
None
:
...
...
mmdet/core/losses/losses.py
View file @
98b20b9b
...
@@ -30,13 +30,13 @@ def sigmoid_focal_loss(pred,
...
@@ -30,13 +30,13 @@ def sigmoid_focal_loss(pred,
weight
,
weight
,
gamma
=
2.0
,
gamma
=
2.0
,
alpha
=
0.25
,
alpha
=
0.25
,
size_average
=
True
):
reduction
=
'elementwise_mean'
):
pred_sigmoid
=
pred
.
sigmoid
()
pred_sigmoid
=
pred
.
sigmoid
()
pt
=
(
1
-
pred_sigmoid
)
*
target
+
pred_sigmoid
*
(
1
-
target
)
pt
=
(
1
-
pred_sigmoid
)
*
target
+
pred_sigmoid
*
(
1
-
target
)
weight
=
(
alpha
*
target
+
(
1
-
alpha
)
*
(
1
-
target
))
*
weight
weight
=
(
alpha
*
target
+
(
1
-
alpha
)
*
(
1
-
target
))
*
weight
weight
=
weight
*
pt
.
pow
(
gamma
)
weight
=
weight
*
pt
.
pow
(
gamma
)
return
F
.
binary_cross_entropy_with_logits
(
return
F
.
binary_cross_entropy_with_logits
(
pred
,
target
,
weight
,
size_average
=
size_average
)
pred
,
target
,
weight
,
size_average
=
reduction
)
def
weighted_sigmoid_focal_loss
(
pred
,
def
weighted_sigmoid_focal_loss
(
pred
,
...
...
mmdet/core/rpn_ops/anchor_target.py
View file @
98b20b9b
import
torch
import
torch
from
..bbox_ops
import
bbox_assign
,
bbox
_transform
,
bbox_sampling
from
..bbox_ops
import
bbox_assign
,
bbox
2delta
,
bbox_sampling
from
..utils
import
multi_apply
from
..utils
import
multi_apply
...
@@ -99,8 +99,8 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -99,8 +99,8 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
if
len
(
pos_inds
)
>
0
:
if
len
(
pos_inds
)
>
0
:
pos_anchors
=
anchors
[
pos_inds
,
:]
pos_anchors
=
anchors
[
pos_inds
,
:]
pos_gt_bbox
=
gt_bboxes
[
assigned_gt_inds
[
pos_inds
]
-
1
,
:]
pos_gt_bbox
=
gt_bboxes
[
assigned_gt_inds
[
pos_inds
]
-
1
,
:]
pos_bbox_targets
=
bbox
_transform
(
pos_anchors
,
pos_gt_bbox
,
pos_bbox_targets
=
bbox
2delta
(
pos_anchors
,
pos_gt_bbox
,
target_means
,
target_means
,
target_stds
)
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
1
labels
[
pos_inds
]
=
1
...
...
mmdet/models/backbones/resnet.py
View file @
98b20b9b
...
@@ -27,7 +27,7 @@ class BasicBlock(nn.Module):
...
@@ -27,7 +27,7 @@ class BasicBlock(nn.Module):
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
downsample
=
None
,
downsample
=
None
,
style
=
'
fb
'
):
style
=
'
pytorch
'
):
super
(
BasicBlock
,
self
).
__init__
()
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
...
@@ -66,15 +66,16 @@ class Bottleneck(nn.Module):
...
@@ -66,15 +66,16 @@ class Bottleneck(nn.Module):
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
downsample
=
None
,
downsample
=
None
,
style
=
'
fb
'
,
style
=
'
pytorch
'
,
with_cp
=
False
):
with_cp
=
False
):
"""Bottleneck block
"""Bottleneck block.
if style is "fb", the stride-two layer is the 3x3 conv layer,
if style is "msra", the stride-two layer is the first 1x1 conv layer
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
"""
super
(
Bottleneck
,
self
).
__init__
()
super
(
Bottleneck
,
self
).
__init__
()
assert
style
in
[
'
fb'
,
'msra
'
]
assert
style
in
[
'
pytorch'
,
'caffe
'
]
if
style
==
'
fb
'
:
if
style
==
'
pytorch
'
:
conv1_stride
=
1
conv1_stride
=
1
conv2_stride
=
stride
conv2_stride
=
stride
else
:
else
:
...
@@ -141,7 +142,7 @@ def make_res_layer(block,
...
@@ -141,7 +142,7 @@ def make_res_layer(block,
blocks
,
blocks
,
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
style
=
'
fb
'
,
style
=
'
pytorch
'
,
with_cp
=
False
):
with_cp
=
False
):
downsample
=
None
downsample
=
None
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
...
@@ -175,7 +176,12 @@ def make_res_layer(block,
...
@@ -175,7 +176,12 @@ def make_res_layer(block,
class
ResHead
(
nn
.
Module
):
class
ResHead
(
nn
.
Module
):
def
__init__
(
self
,
block
,
num_blocks
,
stride
=
2
,
dilation
=
1
,
style
=
'fb'
):
def
__init__
(
self
,
block
,
num_blocks
,
stride
=
2
,
dilation
=
1
,
style
=
'pytorch'
):
self
.
layer4
=
make_res_layer
(
self
.
layer4
=
make_res_layer
(
block
,
block
,
1024
,
1024
,
...
@@ -198,7 +204,7 @@ class ResNet(nn.Module):
...
@@ -198,7 +204,7 @@ class ResNet(nn.Module):
dilations
=
(
1
,
1
,
1
,
1
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=-
1
,
frozen_stages
=-
1
,
style
=
'
fb
'
,
style
=
'
pytorch
'
,
sync_bn
=
False
,
sync_bn
=
False
,
with_cp
=
False
,
with_cp
=
False
,
strict_frozen
=
False
):
strict_frozen
=
False
):
...
@@ -237,7 +243,7 @@ class ResNet(nn.Module):
...
@@ -237,7 +243,7 @@ class ResNet(nn.Module):
style
=
self
.
style
,
style
=
self
.
style
,
with_cp
=
with_cp
)
with_cp
=
with_cp
)
self
.
inplanes
=
planes
*
block
.
expansion
self
.
inplanes
=
planes
*
block
.
expansion
se
tattr
(
self
,
layer_name
,
res_layer
)
se
lf
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
res_layers
.
append
(
layer_name
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
layers
)
-
1
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
layers
)
-
1
)
self
.
with_cp
=
with_cp
self
.
with_cp
=
with_cp
...
@@ -314,7 +320,7 @@ def resnet(depth,
...
@@ -314,7 +320,7 @@ def resnet(depth,
dilations
=
(
1
,
1
,
1
,
1
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
2
,
),
out_indices
=
(
2
,
),
frozen_stages
=-
1
,
frozen_stages
=-
1
,
style
=
'
fb
'
,
style
=
'
pytorch
'
,
sync_bn
=
False
,
sync_bn
=
False
,
with_cp
=
False
,
with_cp
=
False
,
strict_frozen
=
False
):
strict_frozen
=
False
):
...
...
mmdet/models/bbox_heads/bbox_head.py
View file @
98b20b9b
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmdet.core
import
(
bbox_transform_inv
,
multiclass_nms
,
bbox_target
,
from
mmdet.core
import
(
delta2bbox
,
multiclass_nms
,
bbox_target
,
weighted_cross_entropy
,
weighted_smoothl1
,
accuracy
)
weighted_cross_entropy
,
weighted_smoothl1
,
accuracy
)
...
@@ -101,9 +101,8 @@ class BBoxHead(nn.Module):
...
@@ -101,9 +101,8 @@ class BBoxHead(nn.Module):
scores
=
F
.
softmax
(
cls_score
,
dim
=
1
)
if
cls_score
is
not
None
else
None
scores
=
F
.
softmax
(
cls_score
,
dim
=
1
)
if
cls_score
is
not
None
else
None
if
bbox_pred
is
not
None
:
if
bbox_pred
is
not
None
:
bboxes
=
bbox_transform_inv
(
rois
[:,
1
:],
bbox_pred
,
bboxes
=
delta2bbox
(
rois
[:,
1
:],
bbox_pred
,
self
.
target_means
,
self
.
target_means
,
self
.
target_stds
,
self
.
target_stds
,
img_shape
)
img_shape
)
else
:
else
:
bboxes
=
rois
[:,
1
:]
bboxes
=
rois
[:,
1
:]
# TODO: add clip here
# TODO: add clip here
...
...
mmdet/models/roi_extractors/single_level.py
View file @
98b20b9b
...
@@ -41,10 +41,10 @@ class SingleLevelRoI(nn.Module):
...
@@ -41,10 +41,10 @@ class SingleLevelRoI(nn.Module):
def
map_roi_levels
(
self
,
rois
,
num_levels
):
def
map_roi_levels
(
self
,
rois
,
num_levels
):
"""Map rois to corresponding feature levels (0-based) by scales.
"""Map rois to corresponding feature levels (0-based) by scales.
scale < finest_scale: level 0
-
scale < finest_scale: level 0
finest_scale <= scale < finest_scale * 2: level 1
-
finest_scale <= scale < finest_scale * 2: level 1
finest_scale * 2 <= scale < finest_scale * 4: level 2
-
finest_scale * 2 <= scale < finest_scale * 4: level 2
scale >= finest_scale * 4: level 3
-
scale >= finest_scale * 4: level 3
"""
"""
scale
=
torch
.
sqrt
(
scale
=
torch
.
sqrt
(
(
rois
[:,
3
]
-
rois
[:,
1
]
+
1
)
*
(
rois
[:,
4
]
-
rois
[:,
2
]
+
1
))
(
rois
[:,
3
]
-
rois
[:,
1
]
+
1
)
*
(
rois
[:,
4
]
-
rois
[:,
2
]
+
1
))
...
@@ -52,12 +52,13 @@ class SingleLevelRoI(nn.Module):
...
@@ -52,12 +52,13 @@ class SingleLevelRoI(nn.Module):
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
return
target_lvls
return
target_lvls
def
sample_proposals
(
self
,
proposals
,
gt_bboxes
,
gt_
crowds
,
gt_labels
,
def
sample_proposals
(
self
,
proposals
,
gt_bboxes
,
gt_
bboxes_ignore
,
cfg
):
gt_labels
,
cfg
):
proposals
=
proposals
[:,
:
4
]
proposals
=
proposals
[:,
:
4
]
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
bbox_assign
(
proposals
,
gt_bboxes
,
gt_crowds
,
gt_labels
,
bbox_assign
(
proposals
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
min_pos_iou
,
cfg
.
crowd_thr
)
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
min_pos_iou
,
cfg
.
crowd_thr
)
if
cfg
.
add_gt_as_proposals
:
if
cfg
.
add_gt_as_proposals
:
proposals
=
torch
.
cat
([
gt_bboxes
,
proposals
],
dim
=
0
)
proposals
=
torch
.
cat
([
gt_bboxes
,
proposals
],
dim
=
0
)
...
@@ -80,7 +81,8 @@ class SingleLevelRoI(nn.Module):
...
@@ -80,7 +81,8 @@ class SingleLevelRoI(nn.Module):
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
return
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
return
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
def
forward
(
self
,
feats
,
rois
):
def
forward
(
self
,
feats
,
rois
):
"""Extract roi features with the roi layer. If multiple feature levels
"""Extract roi features with the roi layer. If multiple feature levels
...
...
mmdet/models/rpn_heads/rpn_head.py
View file @
98b20b9b
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
bbox_transform_inv
,
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
delta2bbox
,
multi_apply
,
weighted_cross_entropy
,
weighted_smoothl1
,
multi_apply
,
weighted_cross_entropy
,
weighted_smoothl1
,
weighted_binary_cross_entropy
)
weighted_binary_cross_entropy
)
from
mmdet.ops
import
nms
from
mmdet.ops
import
nms
...
@@ -225,9 +225,8 @@ class RPNHead(nn.Module):
...
@@ -225,9 +225,8 @@ class RPNHead(nn.Module):
rpn_bbox_pred
=
rpn_bbox_pred
[
order
,
:]
rpn_bbox_pred
=
rpn_bbox_pred
[
order
,
:]
anchors
=
anchors
[
order
,
:]
anchors
=
anchors
[
order
,
:]
scores
=
scores
[
order
]
scores
=
scores
[
order
]
proposals
=
bbox_transform_inv
(
anchors
,
rpn_bbox_pred
,
proposals
=
delta2bbox
(
anchors
,
rpn_bbox_pred
,
self
.
target_means
,
self
.
target_means
,
self
.
target_stds
,
self
.
target_stds
,
img_shape
)
img_shape
)
w
=
proposals
[:,
2
]
-
proposals
[:,
0
]
+
1
w
=
proposals
[:,
2
]
-
proposals
[:,
0
]
+
1
h
=
proposals
[:,
3
]
-
proposals
[:,
1
]
+
1
h
=
proposals
[:,
3
]
-
proposals
[:,
1
]
+
1
valid_inds
=
torch
.
nonzero
((
w
>=
cfg
.
min_bbox_size
)
&
valid_inds
=
torch
.
nonzero
((
w
>=
cfg
.
min_bbox_size
)
&
...
...
tools/configs/r50_fpn_frcnn_1x.py
View file @
98b20b9b
...
@@ -8,7 +8,7 @@ model = dict(
...
@@ -8,7 +8,7 @@ model = dict(
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=
1
,
frozen_stages
=
1
,
style
=
'
fb
'
),
style
=
'
pytorch
'
),
neck
=
dict
(
neck
=
dict
(
type
=
'FPN'
,
type
=
'FPN'
,
in_channels
=
[
256
,
512
,
1024
,
2048
],
in_channels
=
[
256
,
512
,
1024
,
2048
],
...
...
tools/configs/r50_fpn_maskrcnn_1x.py
View file @
98b20b9b
...
@@ -8,7 +8,7 @@ model = dict(
...
@@ -8,7 +8,7 @@ model = dict(
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=
1
,
frozen_stages
=
1
,
style
=
'
fb
'
),
style
=
'
pytorch
'
),
neck
=
dict
(
neck
=
dict
(
type
=
'FPN'
,
type
=
'FPN'
,
in_channels
=
[
256
,
512
,
1024
,
2048
],
in_channels
=
[
256
,
512
,
1024
,
2048
],
...
...
tools/configs/r50_fpn_rpn_1x.py
View file @
98b20b9b
...
@@ -8,7 +8,7 @@ model = dict(
...
@@ -8,7 +8,7 @@ model = dict(
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=
1
,
frozen_stages
=
1
,
style
=
'
fb
'
),
style
=
'
pytorch
'
),
neck
=
dict
(
neck
=
dict
(
type
=
'FPN'
,
type
=
'FPN'
,
in_channels
=
[
256
,
512
,
1024
,
2048
],
in_channels
=
[
256
,
512
,
1024
,
2048
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment