Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
427c8902
Commit
427c8902
authored
Sep 25, 2018
by
pangjm
Browse files
add Faster RCNN & Mask RCNN training API and some test related
parent
65642939
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
209 additions
and
79 deletions
+209
-79
mmdet/core/bbox_ops/__init__.py
mmdet/core/bbox_ops/__init__.py
+2
-3
mmdet/core/losses/losses.py
mmdet/core/losses/losses.py
+1
-1
mmdet/core/mask_ops/__init__.py
mmdet/core/mask_ops/__init__.py
+2
-2
mmdet/core/mask_ops/mask_target.py
mmdet/core/mask_ops/mask_target.py
+12
-8
mmdet/core/mask_ops/utils.py
mmdet/core/mask_ops/utils.py
+10
-11
mmdet/models/detectors/__init__.py
mmdet/models/detectors/__init__.py
+3
-1
mmdet/models/detectors/faster_rcnn.py
mmdet/models/detectors/faster_rcnn.py
+23
-0
mmdet/models/detectors/mask_rcnn.py
mmdet/models/detectors/mask_rcnn.py
+27
-0
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+91
-48
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+7
-5
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+31
-0
No files found.
mmdet/core/bbox_ops/__init__.py
View file @
427c8902
from
.geometry
import
bbox_overlaps
from
.sampling
import
(
random_choice
,
bbox_assign
,
bbox_assign_via_overlaps
,
bbox_sampling
,
sample_positives
,
sample_negatives
,
sample_proposals
)
bbox_sampling
,
sample_positives
,
sample_negatives
)
from
.transforms
import
(
bbox_transform
,
bbox_transform_inv
,
bbox_flip
,
bbox_mapping
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
...
...
@@ -12,5 +11,5 @@ __all__ = [
'bbox_assign_via_overlaps'
,
'bbox_sampling'
,
'sample_positives'
,
'sample_negatives'
,
'bbox_transform'
,
'bbox_transform_inv'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
,
'sample_proposals'
'bbox_target'
]
mmdet/core/losses/losses.py
View file @
427c8902
...
...
@@ -58,7 +58,7 @@ def mask_cross_entropy(pred, target, label):
inds
=
torch
.
arange
(
0
,
num_rois
,
dtype
=
torch
.
long
,
device
=
pred
.
device
)
pred_slice
=
pred
[
inds
,
label
].
squeeze
(
1
)
return
F
.
binary_cross_entropy_with_logits
(
pred_slice
,
target
,
reduction
=
'
sum
'
)[
None
]
pred_slice
,
target
,
reduction
=
'
elementwise_mean
'
)[
None
]
def
weighted_mask_cross_entropy
(
pred
,
target
,
weight
,
label
):
...
...
mmdet/core/mask_ops/__init__.py
View file @
427c8902
from
.segms
import
(
flip_segms
,
polys_to_mask
,
mask_to_bbox
,
polys_to_mask_wrt_box
,
polys_to_boxes
,
rle_mask_voting
,
rle_mask_nms
,
rle_masks_to_boxes
)
from
.utils
import
split_combined_
gt_
polys
from
.utils
import
split_combined_polys
from
.mask_target
import
mask_target
__all__
=
[
'flip_segms'
,
'polys_to_mask'
,
'mask_to_bbox'
,
'polys_to_mask_wrt_box'
,
'polys_to_boxes'
,
'rle_mask_voting'
,
'rle_mask_nms'
,
'rle_masks_to_boxes'
,
'split_combined_
gt_
polys'
,
'mask_target'
'split_combined_polys'
,
'mask_target'
]
mmdet/core/mask_ops/mask_target.py
View file @
427c8902
...
...
@@ -4,27 +4,31 @@ import numpy as np
from
.segms
import
polys_to_mask_wrt_box
def
mask_target
(
pos_proposals_list
,
pos_assigned_gt_inds_list
,
gt_polys_list
,
img_meta
,
cfg
):
def
mask_target
(
pos_proposals_list
,
pos_assigned_gt_inds_list
,
gt_polys_list
,
img_meta
,
cfg
):
cfg_list
=
[
cfg
for
_
in
range
(
len
(
pos_proposals_list
))]
img_metas
=
[
img_meta
for
_
in
range
(
len
(
pos_proposals_list
))]
mask_targets
=
map
(
mask_target_single
,
pos_proposals_list
,
pos_assigned_gt_inds_list
,
gt_polys_list
,
img_meta
s
,
pos_assigned_gt_inds_list
,
gt_polys_list
,
img_meta
,
cfg_list
)
mask_targets
=
torch
.
cat
(
tuple
(
mask_targets
),
dim
=
0
)
return
mask_targets
def
mask_target_single
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_polys
,
img_meta
,
cfg
):
def
mask_target_single
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_polys
,
img_meta
,
cfg
):
mask_size
=
cfg
.
mask_size
num_pos
=
pos_proposals
.
size
(
0
)
mask_targets
=
pos_proposals
.
new_zeros
((
num_pos
,
mask_size
,
mask_size
))
if
num_pos
>
0
:
pos_proposals
=
pos_proposals
.
cpu
().
numpy
()
pos_assigned_gt_inds
=
pos_assigned_gt_inds
.
cpu
().
numpy
()
scale_factor
=
img_meta
[
'scale_factor'
]
[
0
].
cpu
().
numpy
()
scale_factor
=
img_meta
[
'scale_factor'
]
for
i
in
range
(
num_pos
):
bbox
=
pos_proposals
[
i
,
:]
/
scale_factor
polys
=
gt_polys
[
pos_assigned_gt_inds
[
i
]]
...
...
mmdet/core/mask_ops/utils.py
View file @
427c8902
import
mmcv
def
split_combined_
gt_
polys
(
gt_
polys
,
gt_
poly_lens
,
num_
polys_per_mask
):
def
split_combined_polys
(
polys
,
poly_lens
,
polys_per_mask
):
"""Split the combined 1-D polys into masks.
A mask is represented as a list of polys, and a poly is represented as
...
...
@@ -9,9 +9,9 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
tensor. Here we need to split the tensor into original representations.
Args:
gt_
polys (list): a list (length = image num) of 1-D tensors
gt_
poly_lens (list): a list (length = image num) of poly length
num_
polys_per_mask (list): a list (length = image num) of poly number
polys (list): a list (length = image num) of 1-D tensors
poly_lens (list): a list (length = image num) of poly length
polys_per_mask (list): a list (length = image num) of poly number
of each mask
Returns:
...
...
@@ -19,13 +19,12 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
list (length = poly num) of numpy array
"""
mask_polys_list
=
[]
for
img_id
in
range
(
len
(
gt_polys
)):
gt_polys_single
=
gt_polys
[
img_id
].
cpu
().
numpy
()
gt_polys_lens_single
=
gt_poly_lens
[
img_id
].
cpu
().
numpy
().
tolist
()
num_polys_per_mask_single
=
num_polys_per_mask
[
img_id
].
cpu
().
numpy
().
tolist
()
for
img_id
in
range
(
len
(
polys
)):
polys_single
=
polys
[
img_id
]
polys_lens_single
=
poly_lens
[
img_id
].
tolist
()
polys_per_mask_single
=
polys_per_mask
[
img_id
].
tolist
()
split_
gt_
polys
=
mmcv
.
slice_list
(
gt_
polys_single
,
gt_
polys_lens_single
)
mask_polys
=
mmcv
.
slice_list
(
split_
gt_
polys
,
num_
polys_per_mask_single
)
split_polys
=
mmcv
.
slice_list
(
polys_single
,
polys_lens_single
)
mask_polys
=
mmcv
.
slice_list
(
split_polys
,
polys_per_mask_single
)
mask_polys_list
.
append
(
mask_polys
)
return
mask_polys_list
mmdet/models/detectors/__init__.py
View file @
427c8902
from
.base
import
BaseDetector
from
.rpn
import
RPN
from
.faster_rcnn
import
FasterRCNN
from
.mask_rcnn
import
MaskRCNN
__all__
=
[
'BaseDetector'
,
'RPN'
]
__all__
=
[
'BaseDetector'
,
'RPN'
,
'FasterRCNN'
,
'MaskRCNN'
]
mmdet/models/detectors/faster_rcnn.py
View file @
427c8902
from
.two_stage
import
TwoStageDetector
class
FasterRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
backbone
,
neck
,
rpn_head
,
bbox_roi_extractor
,
bbox_head
,
train_cfg
,
test_cfg
,
pretrained
=
None
):
super
(
FasterRCNN
,
self
).
__init__
(
backbone
=
backbone
,
neck
=
neck
,
rpn_head
=
rpn_head
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_head
=
bbox_head
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
mmdet/models/detectors/mask_rcnn.py
View file @
427c8902
from
.two_stage
import
TwoStageDetector
class
MaskRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
backbone
,
neck
,
rpn_head
,
bbox_roi_extractor
,
bbox_head
,
mask_roi_extractor
,
mask_head
,
train_cfg
,
test_cfg
,
pretrained
=
None
):
super
(
MaskRCNN
,
self
).
__init__
(
backbone
=
backbone
,
neck
=
neck
,
rpn_head
=
rpn_head
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_head
=
bbox_head
,
mask_roi_extractor
=
mask_roi_extractor
,
mask_head
=
mask_head
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
mmdet/models/detectors/two_stage.py
View file @
427c8902
import
torch
import
torch.nn
as
nn
from
.base
import
Detector
from
.test
ing
_mixins
import
RPNTestMixin
,
BBoxTestMixin
from
.base
import
Base
Detector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
mmdet.core
import
bbox2roi
,
bbox2result
,
s
ample_proposals
from
mmdet.core
import
bbox2roi
,
bbox2result
,
s
plit_combined_polys
,
multi_apply
class
TwoStageDetector
(
Detector
,
RPNTestMixin
,
BBoxTestMixin
):
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
):
def
__init__
(
self
,
backbone
,
...
...
@@ -15,13 +16,16 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_head
=
None
,
bbox_roi_extractor
=
None
,
bbox_head
=
None
,
mask_roi_extractor
=
None
,
mask_head
=
None
,
train_cfg
=
None
,
test_cfg
=
None
,
pretrained
=
None
):
super
(
Detector
,
self
).
__init__
()
super
(
TwoStage
Detector
,
self
).
__init__
()
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
with_neck
=
True
if
neck
is
not
None
else
False
assert
self
.
with_neck
,
"TwoStageDetector must be implemented with FPN now."
if
self
.
with_neck
:
self
.
neck
=
builder
.
build_neck
(
neck
)
...
...
@@ -35,6 +39,12 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
bbox_roi_extractor
)
self
.
bbox_head
=
builder
.
build_bbox_head
(
bbox_head
)
self
.
with_mask
=
True
if
mask_head
is
not
None
else
False
if
self
.
with_mask
:
self
.
mask_roi_extractor
=
builder
.
build_roi_extractor
(
mask_roi_extractor
)
self
.
mask_head
=
builder
.
build_mask_head
(
mask_head
)
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
...
...
@@ -68,6 +78,7 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
gt_masks
=
None
,
proposals
=
None
):
losses
=
dict
()
...
...
@@ -80,54 +91,73 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_losses
=
self
.
rpn_head
.
loss
(
*
rpn_loss_inputs
)
losses
.
update
(
rpn_losses
)
proposal_inputs
=
rpn_outs
+
(
img_meta
,
self
.
self
.
test_cfg
.
rpn
)
proposal_inputs
=
rpn_outs
+
(
img_meta
,
self
.
test_cfg
.
rpn
)
proposal_list
=
self
.
rpn_head
.
get_proposals
(
*
proposal_inputs
)
else
:
proposal_list
=
proposals
(
pos_inds
,
neg_inds
,
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
=
sample_proposals
(
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
self
.
train_cfg
.
rcnn
)
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
\
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
if
self
.
with_bbox
:
rcnn_train_cfg_list
=
[
self
.
train_cfg
.
rcnn
for
_
in
range
(
len
(
proposal_list
))
]
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
=
multi_apply
(
self
.
bbox_roi_extractor
.
sample_proposals
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
rcnn_train_cfg_list
)
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
\
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
for
pos
,
neg
in
zip
(
pos_proposals
,
neg_proposals
)
])
# TODO: a more flexible way to configurate feat maps
roi_feats
=
self
.
bbox_roi_extractor
(
x
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
roi_feats
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
losses
.
update
(
loss_bbox
)
if
self
.
with_mask
:
gt_polys
=
split_combined_polys
(
**
gt_masks
)
mask_targets
=
self
.
mask_head
.
get_mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_polys
,
img_meta
,
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
for
pos
,
neg
in
zip
(
pos_proposals
,
neg_proposals
)
])
# TODO: a more flexible way to configurate feat maps
roi_feats
=
self
.
bbox_roi_extractor
(
x
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
roi_feats
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
losses
.
update
(
loss_bbox
)
pos_rois
=
bbox2roi
(
pos_proposals
)
mask_feats
=
self
.
mask_roi_extractor
(
x
[:
self
.
mask_roi_extractor
.
num_inputs
],
pos_rois
)
mask_pred
=
self
.
mask_head
(
mask_feats
)
loss_mask
=
self
.
mask_head
.
loss
(
mask_pred
,
mask_targets
,
torch
.
cat
(
pos_gt_labels
))
losses
.
update
(
loss_mask
)
return
losses
def
simple_test
(
self
,
img
,
img_meta
,
proposals
=
None
,
rescale
=
False
):
"""Test without augmentation."""
assert
proposals
==
None
,
"Fast RCNN hasn't been implemented."
assert
self
.
with_bbox
,
"Bbox head must be implemented."
x
=
self
.
extract_feat
(
img
)
if
proposals
is
None
:
proposals
=
self
.
simple_test_rpn
(
x
,
img_meta
)
if
self
.
with_bbox
:
# BUG proposals shape?
det_bboxes
,
det_labels
=
self
.
simple_test_bboxes
(
x
,
img_meta
,
[
proposals
],
rescale
=
rescale
)
bbox_result
=
bbox2result
(
det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
return
bbox_result
proposal_list
=
self
.
simple_test_rpn
(
x
,
img_meta
,
self
.
test_cfg
.
rpn
)
if
proposals
is
None
else
proposals
det_bboxes
,
det_labels
=
self
.
simple_test_bboxes
(
x
,
img_meta
,
proposal_list
,
self
.
test_cfg
.
rcnn
,
rescale
=
rescale
)
bbox_results
=
bbox2result
(
det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
if
self
.
with_mask
:
segm_results
=
self
.
simple_test_mask
(
x
,
img_meta
,
det_bboxes
,
det_labels
,
rescale
=
rescale
)
return
bbox_results
,
segm_results
else
:
proposals
[:,
:
4
]
/=
img_meta
[
'scale_factor'
].
float
()
return
proposals
.
cpu
().
numpy
()
return
bbox_results
def
aug_test
(
self
,
imgs
,
img_metas
,
rescale
=
False
):
"""Test with augmentations.
...
...
@@ -135,15 +165,28 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
proposals
=
self
.
aug_test_rpn
(
self
.
extract_feats
(
imgs
),
img_metas
,
self
.
rpn_test_cfg
)
# recompute self.extract_feats(imgs) because of 'yield' and memory
proposal_list
=
self
.
aug_test_rpn
(
self
.
extract_feats
(
imgs
),
img_metas
,
self
.
test_cfg
.
rpn
)
det_bboxes
,
det_labels
=
self
.
aug_test_bboxes
(
self
.
extract_feats
(
imgs
),
img_metas
,
proposals
,
self
.
rcnn_test_cfg
)
self
.
extract_feats
(
imgs
),
img_metas
,
proposal_list
,
self
.
test_cfg
.
rcnn
)
if
rescale
:
_det_bboxes
=
det_bboxes
else
:
_det_bboxes
=
det_bboxes
.
clone
()
_det_bboxes
[:,
:
4
]
*=
img_metas
[
0
][
'shape_scale'
][
0
][
-
1
]
bbox_result
=
bbox2result
(
_det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
return
bbox_result
_det_bboxes
[:,
:
4
]
*=
img_metas
[
0
][
0
][
'scale_factor'
]
bbox_results
=
bbox2result
(
_det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
# det_bboxes always keep the original scale
if
self
.
with_mask
:
segm_results
=
self
.
aug_test_mask
(
self
.
extract_feats
(
imgs
),
img_metas
,
det_bboxes
,
det_labels
)
return
bbox_results
,
segm_results
else
:
return
bbox_results
mmdet/models/mask_heads/fcn_mask_head.py
View file @
427c8902
...
...
@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return
mask_targets
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
loss
=
dict
()
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
return
loss_mask
loss
[
'loss_mask'
]
=
loss_mask
return
loss
def
get_seg_masks
(
self
,
mask_pred
,
det_bboxes
,
det_labels
,
rcnn_test_cfg
,
ori_s
cal
e
):
ori_s
hap
e
):
"""Get segmentation masks from mask_pred and bboxes
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
...
...
@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to
original image size
ori_shape:
original image size
Returns:
list[list]: encoded masks
"""
...
...
@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms
=
[[]
for
_
in
range
(
self
.
num_classes
-
1
)]
bboxes
=
det_bboxes
.
cpu
().
numpy
()[:,
:
4
]
labels
=
det_labels
.
cpu
().
numpy
()
+
1
img_h
=
ori_s
cal
e
[
0
]
img_w
=
ori_s
cal
e
[
1
]
img_h
=
ori_s
hap
e
[
0
]
img_w
=
ori_s
hap
e
[
1
]
for
i
in
range
(
bboxes
.
shape
[
0
]):
bbox
=
bboxes
[
i
,
:].
astype
(
int
)
...
...
mmdet/models/roi_extractors/single_level.py
View file @
427c8902
...
...
@@ -4,6 +4,7 @@ import torch
import
torch.nn
as
nn
from
mmdet
import
ops
from
mmdet.core
import
bbox_assign
,
bbox_sampling
class
SingleLevelRoI
(
nn
.
Module
):
...
...
@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
return
target_lvls
def
sample_proposals
(
self
,
proposals
,
gt_bboxes
,
gt_crowds
,
gt_labels
,
cfg
):
proposals
=
proposals
[:,
:
4
]
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
bbox_assign
(
proposals
,
gt_bboxes
,
gt_crowds
,
gt_labels
,
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
pos_iou_thr
,
cfg
.
crowd_thr
)
if
cfg
.
add_gt_as_proposals
:
proposals
=
torch
.
cat
([
gt_bboxes
,
proposals
],
dim
=
0
)
gt_assign_self
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
proposals
.
device
)
assigned_gt_inds
=
torch
.
cat
([
gt_assign_self
,
assigned_gt_inds
])
assigned_labels
=
torch
.
cat
([
gt_labels
,
assigned_labels
])
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
roi_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
pos_proposals
=
proposals
[
pos_inds
]
neg_proposals
=
proposals
[
neg_inds
]
pos_assigned_gt_inds
=
assigned_gt_inds
[
pos_inds
]
-
1
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
return
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
def
forward
(
self
,
feats
,
rois
):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment