Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
2017c81e
Commit
2017c81e
authored
Dec 20, 2018
by
Kai Chen
Browse files
Merge branch 'master' into pytorch-1.0
parents
c4408812
6594f862
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
334 additions
and
80 deletions
+334
-80
mmdet/core/bbox/samplers/base_sampler.py
mmdet/core/bbox/samplers/base_sampler.py
+22
-8
mmdet/core/bbox/samplers/combined_sampler.py
mmdet/core/bbox/samplers/combined_sampler.py
+12
-10
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+1
-1
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+1
-1
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+68
-0
mmdet/core/bbox/samplers/pseudo_sampler.py
mmdet/core/bbox/samplers/pseudo_sampler.py
+4
-4
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+6
-8
mmdet/core/evaluation/__init__.py
mmdet/core/evaluation/__init__.py
+2
-2
mmdet/core/evaluation/class_names.py
mmdet/core/evaluation/class_names.py
+8
-8
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+39
-0
mmdet/core/loss/losses.py
mmdet/core/loss/losses.py
+6
-2
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+5
-3
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+15
-0
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+8
-6
mmdet/datasets/custom.py
mmdet/datasets/custom.py
+4
-2
mmdet/datasets/repeat_dataset.py
mmdet/datasets/repeat_dataset.py
+5
-3
mmdet/datasets/voc.py
mmdet/datasets/voc.py
+18
-0
mmdet/datasets/xml_style.py
mmdet/datasets/xml_style.py
+76
-0
mmdet/models/backbones/__init__.py
mmdet/models/backbones/__init__.py
+2
-1
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+32
-21
No files found.
mmdet/core/bbox/samplers/base_sampler.py
View file @
2017c81e
...
...
@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult
class
BaseSampler
(
metaclass
=
ABCMeta
):
def
__init__
(
self
):
def
__init__
(
self
,
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
**
kwargs
):
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
self
.
pos_sampler
=
self
self
.
neg_sampler
=
self
@
abstractmethod
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
@
abstractmethod
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
,
**
kwargs
):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
...
...
@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
assign_result
,
num_expected_po
s
)
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
assign_result
,
num_expected_pos
,
bboxes
=
bboxes
,
**
kwarg
s
)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
...
...
@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound
=
int
(
self
.
neg_pos_ub
*
_pos
)
if
num_expected_neg
>
neg_upper_bound
:
num_expected_neg
=
neg_upper_bound
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
assign_result
,
num_expected_neg
)
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
assign_result
,
num_expected_neg
,
bboxes
=
bboxes
,
**
kwargs
)
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
...
...
mmdet/core/bbox/samplers/combined_sampler.py
View file @
2017c81e
from
.
random
_sampler
import
Random
Sampler
from
.
base
_sampler
import
Base
Sampler
from
..assign_sampling
import
build_sampler
class
CombinedSampler
(
Random
Sampler
):
class
CombinedSampler
(
Base
Sampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
super
(
CombinedSampler
,
self
).
__init__
(
num
,
pos_fraction
,
**
kwargs
)
default_args
=
dict
(
num
=
num
,
pos_fraction
=
pos_fraction
)
default_args
.
update
(
kwargs
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
default_args
=
default_args
)
self
.
neg_sampler
=
build_sampler
(
neg_sampler
,
default_args
=
default_args
)
def
__init__
(
self
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
super
(
CombinedSampler
,
self
).
__init__
(
**
kwargs
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
**
kwargs
)
self
.
neg_sampler
=
build_sampler
(
neg_sampler
,
**
kwargs
)
def
_sample_pos
(
self
,
**
kwargs
):
raise
NotImplementedError
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
View file @
2017c81e
...
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
View file @
2017c81e
...
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/ohem_sampler.py
0 → 100644
View file @
2017c81e
import
torch
from
.base_sampler
import
BaseSampler
from
..transforms
import
bbox2roi
class OHEMSampler(BaseSampler):
    """Sampler that keeps the candidates with the largest classification loss.

    The RoI extractor and bbox head are taken from ``context`` and are used
    (without gradient tracking) to score every candidate box; the
    ``num_expected`` highest-loss candidates are then selected.

    Args:
        num: Forwarded to ``BaseSampler``.
        pos_fraction: Forwarded to ``BaseSampler``.
        context: Object exposing ``bbox_roi_extractor`` and ``bbox_head``
            attributes.
        neg_pos_ub: Forwarded to ``BaseSampler``.
        add_gt_as_proposals: Forwarded to ``BaseSampler``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        self.bbox_roi_extractor = context.bbox_roi_extractor
        self.bbox_head = context.bbox_head

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Return the entries of ``inds`` with the top classification loss.

        Args:
            inds: Candidate indices; indexed by the top-k loss positions.
            num_expected: Number of candidates to keep (``k`` of the top-k).
            bboxes: Boxes of the candidates, converted to RoIs.
            labels: Labels of the candidates, passed to the head's loss.
            feats: Feature maps; only the first
                ``bbox_roi_extractor.num_inputs`` entries are used.

        Returns:
            ``inds`` indexed by the positions of the ``num_expected`` largest
            per-candidate ``loss_cls`` values.
        """
        # Scoring is only used for ranking, so no gradients are recorded.
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            roi_feats = self.bbox_roi_extractor(
                feats[:self.bbox_roi_extractor.num_inputs], rois)
            cls_score, _ = self.bbox_head(roi_feats)
            losses = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduce=False)
            # reduce=False is requested so that topk can rank candidates
            # individually.
            hardest = losses['loss_cls'].topk(num_expected)[1]
        return inds[hardest]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard positives, i.e. candidates with ``gt_inds > 0``."""
        candidates = torch.nonzero(assign_result.gt_inds > 0)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Not more candidates than requested: keep them all, no mining needed.
        if candidates.numel() <= num_expected:
            return candidates
        return self.hard_mining(candidates, num_expected, bboxes[candidates],
                                assign_result.labels[candidates], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard negatives, i.e. candidates with ``gt_inds == 0``."""
        candidates = torch.nonzero(assign_result.gt_inds == 0)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Not more candidates than requested: keep them all, no mining needed.
        if len(candidates) <= num_expected:
            return candidates
        return self.hard_mining(candidates, num_expected, bboxes[candidates],
                                assign_result.labels[candidates], feats)
mmdet/core/bbox/samplers/pseudo_sampler.py
View file @
2017c81e
...
...
@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class
PseudoSampler
(
BaseSampler
):
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
pass
def
_sample_pos
(
self
):
def
_sample_pos
(
self
,
**
kwargs
):
raise
NotImplementedError
def
_sample_neg
(
self
):
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
...
...
mmdet/core/bbox/samplers/random_sampler.py
View file @
2017c81e
...
...
@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
):
super
(
RandomSampler
,
self
).
__init__
()
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
add_gt_as_proposals
=
True
,
**
kwargs
):
super
(
RandomSampler
,
self
).
__init__
(
num
,
pos_fraction
,
neg_pos_ub
,
add_gt_as_proposals
)
@
staticmethod
def
random_choice
(
gallery
,
num
):
...
...
@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
...
...
@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
...
...
mmdet/core/evaluation/__init__.py
View file @
2017c81e
...
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
get_classes
)
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.eval_hooks
import
(
DistEvalHook
,
CocoDistEvalRecallHook
,
from
.eval_hooks
import
(
DistEvalHook
,
DistEvalmAPHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
...
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__
=
[
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'DistEvalmAPHook'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'plot_num_recall'
,
'plot_iou_recall'
...
...
mmdet/core/evaluation/class_names.py
View file @
2017c81e
...
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def
coco_classes
():
return
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic
light'
,
'fire
hydrant'
,
'stop
sign'
,
'parking
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'truck'
,
'boat'
,
'traffic
_
light'
,
'fire
_
hydrant'
,
'stop
_
sign'
,
'parking
_
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports
ball'
,
'kite'
,
'baseball
bat'
,
'baseball
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
racket'
,
'bottle'
,
'wine
glass'
,
'cup'
,
'fork'
,
'sports
_
ball'
,
'kite'
,
'baseball
_
bat'
,
'baseball
_
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
_
racket'
,
'bottle'
,
'wine
_
glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
plant'
,
'bed'
,
'dining
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
phone'
,
'microwave'
,
'broccoli'
,
'carrot'
,
'hot
_
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
_
plant'
,
'bed'
,
'dining
_
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
_
phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy
bear'
,
'hair
drier'
,
'toothbrush'
'scissors'
,
'teddy
_
bear'
,
'hair
_
drier'
,
'toothbrush'
]
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
2017c81e
...
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from
torch.utils.data
import
Dataset
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.mean_ap
import
eval_map
from
mmdet
import
datasets
...
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise
NotImplementedError
class DistEvalmAPHook(DistEvalHook):
    """Distributed evaluation hook that reports mean average precision."""

    def evaluate(self, runner, results):
        """Collect ground truth from the dataset and log the mAP.

        Args:
            runner: Runner whose ``log_buffer`` receives the ``'mAP'`` entry.
            results: Detection results, passed unchanged to ``eval_map``.
        """
        gt_bboxes = []
        gt_labels = []
        # Ignore flags are only built when the dataset is configured with
        # crowd/ignored boxes; otherwise eval_map receives None.
        gt_ignore = [] if self.dataset.with_crowd else None
        for i in range(len(self.dataset)):
            ann = self.dataset.get_ann_info(i)
            bboxes = ann['bboxes']
            labels = ann['labels']
            if gt_ignore is not None:
                # Regular boxes come first (flag False), ignored boxes are
                # appended after them (flag True), matching the stacking
                # order of bboxes/labels below.
                # np.bool_ is used because the np.bool alias was removed
                # from NumPy.
                ignore = np.concatenate([
                    np.zeros(bboxes.shape[0], dtype=np.bool_),
                    np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool_)
                ])
                gt_ignore.append(ignore)
                bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
                labels = np.concatenate([labels, ann['labels_ignore']])
            gt_bboxes.append(bboxes)
            gt_labels.append(labels)
        # If the dataset is VOC2007, then use 11 points mAP evaluation.
        if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
            ds_name = 'voc07'
        else:
            ds_name = self.dataset.CLASSES
        mean_ap, _ = eval_map(
            results,
            gt_bboxes,
            gt_labels,
            gt_ignore=gt_ignore,
            scale_ranges=None,
            iou_thr=0.5,
            dataset=ds_name,
            print_summary=True)
        runner.log_buffer.output['mAP'] = mean_ap
        runner.log_buffer.ready = True
class
CocoDistEvalRecallHook
(
DistEvalHook
):
def
__init__
(
self
,
...
...
mmdet/core/loss/losses.py
View file @
2017c81e
...
...
@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
,
reduce
=
True
):
if
avg_factor
is
None
:
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
if
reduce
:
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
else
:
return
raw
*
weight
/
avg_factor
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
...
...
mmdet/datasets/__init__.py
View file @
2017c81e
from
.custom
import
CustomDataset
from
.xml_style
import
XMLDataset
from
.coco
import
CocoDataset
from
.voc
import
VOCDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.concat_dataset
import
ConcatDataset
from
.repeat_dataset
import
RepeatDataset
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'
GroupSampler'
,
'Distributed
GroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
,
'CustomDataset'
,
'XMLDataset'
,
'CocoDataset'
,
'
VOCDataset'
,
'
GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
]
mmdet/datasets/coco.py
View file @
2017c81e
...
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
class
CocoDataset
(
CustomDataset
):
CLASSES
=
(
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic_light'
,
'fire_hydrant'
,
'stop_sign'
,
'parking_meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports_ball'
,
'kite'
,
'baseball_bat'
,
'baseball_glove'
,
'skateboard'
,
'surfboard'
,
'tennis_racket'
,
'bottle'
,
'wine_glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot_dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted_plant'
,
'bed'
,
'dining_table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell_phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy_bear'
,
'hair_drier'
,
'toothbrush'
)
def
load_annotations
(
self
,
ann_file
):
self
.
coco
=
COCO
(
ann_file
)
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
...
...
mmdet/datasets/concat_dataset.py
View file @
2017c81e
...
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
self
.
CLASSES
=
datasets
[
0
].
CLASSES
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
...
...
mmdet/datasets/custom.py
View file @
2017c81e
...
...
@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
"""
CLASSES
=
None
def
__init__
(
self
,
ann_file
,
img_prefix
,
...
...
@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
with_crowd
=
True
,
with_label
=
True
,
test_mode
=
False
):
# prefix of images path
self
.
img_prefix
=
img_prefix
# load annotations (and proposals)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
if
proposal_file
is
not
None
:
...
...
@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
if
self
.
proposals
is
not
None
:
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
# prefix of images path
self
.
img_prefix
=
img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
...
...
mmdet/datasets/repeat_dataset.py
View file @
2017c81e
...
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
def
__init__
(
self
,
dataset
,
times
):
self
.
dataset
=
dataset
self
.
times
=
times
self
.
CLASSES
=
dataset
.
CLASSES
if
hasattr
(
self
.
dataset
,
'flag'
):
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
_original_length
=
len
(
self
.
dataset
)
self
.
_ori_len
=
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
idx
%
self
.
_ori
ginal
_len
gth
]
return
self
.
dataset
[
idx
%
self
.
_ori_len
]
def
__len__
(
self
):
return
self
.
times
*
self
.
_ori
ginal
_len
gth
return
self
.
times
*
self
.
_ori_len
mmdet/datasets/voc.py
0 → 100644
View file @
2017c81e
from
.xml_style
import
XMLDataset
class VOCDataset(XMLDataset):
    """XML-style dataset with the 20 PASCAL VOC class names.

    The dataset year (2007 or 2012) is inferred from ``img_prefix`` and stored
    in ``self.year``.

    Raises:
        ValueError: If ``img_prefix`` contains neither ``'VOC2007'`` nor
            ``'VOC2012'``.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        super(VOCDataset, self).__init__(**kwargs)
        # The first marker found in the image prefix decides the year;
        # VOC2007 is checked before VOC2012.
        for marker, year in (('VOC2007', 2007), ('VOC2012', 2012)):
            if marker in self.img_prefix:
                self.year = year
                break
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')
mmdet/datasets/xml_style.py
0 → 100644
View file @
2017c81e
import
os.path
as
osp
import
xml.etree.ElementTree
as
ET
import
mmcv
import
numpy
as
np
from
.custom
import
CustomDataset
class XMLDataset(CustomDataset):
    """Dataset whose annotations are stored as one XML file per image.

    Image ids are read from ``ann_file``; for every id the image is expected
    at ``JPEGImages/<id>.jpg`` and its annotation at
    ``<img_prefix>/Annotations/<id>.xml``. Category names from ``CLASSES`` are
    mapped to labels starting at 1.
    """

    def __init__(self, **kwargs):
        super(XMLDataset, self).__init__(**kwargs)
        # Labels start at 1.
        self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}

    def _xml_path(self, img_id):
        """Return the annotation XML path for ``img_id``."""
        return osp.join(self.img_prefix, 'Annotations',
                        '{}.xml'.format(img_id))

    @staticmethod
    def _to_arrays(boxes, labels):
        """Convert python lists of boxes/labels to float32/int64 arrays.

        Boxes are shifted by -1; an empty list yields a ``(0, 4)`` box array
        and a ``(0,)`` label array.
        """
        if not boxes:
            box_arr = np.zeros((0, 4))
            label_arr = np.zeros((0, ))
        else:
            box_arr = np.array(boxes, ndmin=2) - 1
            label_arr = np.array(labels)
        return box_arr.astype(np.float32), label_arr.astype(np.int64)

    def load_annotations(self, ann_file):
        """Build the image info list from the id list in ``ann_file``.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``,
            ``width`` and ``height`` (size read from the XML ``size`` node).
        """
        img_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            root = ET.parse(self._xml_path(img_id)).getroot()
            size = root.find('size')
            img_infos.append(
                dict(
                    id=img_id,
                    filename='JPEGImages/{}.jpg'.format(img_id),
                    width=int(size.find('width').text),
                    height=int(size.find('height').text)))
        return img_infos

    def get_ann_info(self, idx):
        """Parse the annotation of image ``idx``.

        Objects flagged as difficult go to the ``*_ignore`` arrays; objects
        without a ``difficult`` tag are treated as not difficult.

        Returns:
            dict: ``bboxes`` / ``bboxes_ignore`` (float32, ``(n, 4)``) and
            ``labels`` / ``labels_ignore`` (int64, ``(n,)``).
        """
        img_id = self.img_infos[idx]['id']
        root = ET.parse(self._xml_path(img_id)).getroot()
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            label = self.cat2label[obj.find('name').text]
            # Element.find returns None for a missing tag; do not crash on
            # annotations that omit <difficult>.
            difficult_node = obj.find('difficult')
            difficult = (0 if difficult_node is None else int(
                difficult_node.text))
            bnd_box = obj.find('bndbox')
            # Coordinates may be written as floats (e.g. "45.7"); going
            # through float() keeps integer inputs unchanged and truncates
            # fractional ones instead of raising ValueError.
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            if difficult:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._to_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._to_arrays(bboxes_ignore,
                                                       labels_ignore)
        ann = dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)
        return ann
mmdet/models/backbones/__init__.py
View file @
2017c81e
from
.resnet
import
ResNet
from
.resnext
import
ResNeXt
__all__
=
[
'ResNet'
]
__all__
=
[
'ResNet'
,
'ResNeXt'
]
mmdet/models/backbones/resnet.py
View file @
2017c81e
...
...
@@ -42,7 +42,7 @@ class BasicBlock(nn.Module):
assert
not
with_cp
def
forward
(
self
,
x
):
residual
=
x
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
...
...
@@ -52,9 +52,9 @@ class BasicBlock(nn.Module):
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
identity
=
self
.
downsample
(
x
)
out
+=
residual
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
...
...
@@ -71,25 +71,31 @@ class Bottleneck(nn.Module):
downsample
=
None
,
style
=
'pytorch'
,
with_cp
=
False
):
"""Bottleneck block.
"""Bottleneck block
for ResNet
.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super
(
Bottleneck
,
self
).
__init__
()
assert
style
in
[
'pytorch'
,
'caffe'
]
self
.
inplanes
=
inplanes
self
.
planes
=
planes
if
style
==
'pytorch'
:
conv1_stride
=
1
conv2_stride
=
stride
self
.
conv1_stride
=
1
self
.
conv2_stride
=
stride
else
:
conv1_stride
=
stride
conv2_stride
=
1
self
.
conv1_stride
=
stride
self
.
conv2_stride
=
1
self
.
conv1
=
nn
.
Conv2d
(
inplanes
,
planes
,
kernel_size
=
1
,
stride
=
conv1_stride
,
bias
=
False
)
inplanes
,
planes
,
kernel_size
=
1
,
stride
=
self
.
conv1_stride
,
bias
=
False
)
self
.
conv2
=
nn
.
Conv2d
(
planes
,
planes
,
kernel_size
=
3
,
stride
=
conv2_stride
,
stride
=
self
.
conv2_stride
,
padding
=
dilation
,
dilation
=
dilation
,
bias
=
False
)
...
...
@@ -108,7 +114,7 @@ class Bottleneck(nn.Module):
def
forward
(
self
,
x
):
def
_inner_forward
(
x
):
residual
=
x
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
...
...
@@ -122,9 +128,9 @@ class Bottleneck(nn.Module):
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
identity
=
self
.
downsample
(
x
)
out
+=
residual
out
+=
identity
return
out
...
...
@@ -219,20 +225,24 @@ class ResNet(nn.Module):
super
(
ResNet
,
self
).
__init__
()
if
depth
not
in
self
.
arch_settings
:
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
self
.
depth
=
depth
self
.
num_stages
=
num_stages
assert
num_stages
>=
1
and
num_stages
<=
4
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
s
tage_blocks
=
stage_blocks
[:
num_stages
]
self
.
strides
=
strides
s
elf
.
dilations
=
dilations
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
self
.
out_indices
=
out_indices
assert
max
(
out_indices
)
<
num_stages
self
.
style
=
style
self
.
frozen_stages
=
frozen_stages
self
.
bn_eval
=
bn_eval
self
.
bn_frozen
=
bn_frozen
self
.
with_cp
=
with_cp
self
.
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
self
.
stage_blocks
=
stage_blocks
[:
num_stages
]
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
...
...
@@ -240,12 +250,12 @@ class ResNet(nn.Module):
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
stage_blocks
):
for
i
,
num_blocks
in
enumerate
(
self
.
stage_blocks
):
stride
=
strides
[
i
]
dilation
=
dilations
[
i
]
planes
=
64
*
2
**
i
res_layer
=
make_res_layer
(
block
,
self
.
block
,
self
.
inplanes
,
planes
,
num_blocks
,
...
...
@@ -253,12 +263,13 @@ class ResNet(nn.Module):
dilation
=
dilation
,
style
=
self
.
style
,
with_cp
=
with_cp
)
self
.
inplanes
=
planes
*
block
.
expansion
self
.
inplanes
=
planes
*
self
.
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
stage_blocks
)
-
1
)
self
.
feat_dim
=
self
.
block
.
expansion
*
64
*
2
**
(
len
(
self
.
stage_blocks
)
-
1
)
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment