Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
2de84ef8
Commit
2de84ef8
authored
Dec 20, 2018
by
yhcao6
Browse files
resolve conflict
parents
c6a62868
6594f862
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
372 additions
and
59 deletions
+372
-59
mmdet/core/bbox/samplers/combined_sampler.py
mmdet/core/bbox/samplers/combined_sampler.py
+12
-10
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+1
-1
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+1
-1
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+68
-0
mmdet/core/bbox/samplers/pseudo_sampler.py
mmdet/core/bbox/samplers/pseudo_sampler.py
+4
-4
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+6
-8
mmdet/core/loss/losses.py
mmdet/core/loss/losses.py
+6
-2
mmdet/models/backbones/__init__.py
mmdet/models/backbones/__init__.py
+2
-1
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+32
-21
mmdet/models/backbones/resnext.py
mmdet/models/backbones/resnext.py
+149
-0
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+2
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+17
-8
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+5
-1
tools/train.py
tools/train.py
+5
-0
tools/voc_eval.py
tools/voc_eval.py
+62
-0
No files found.
mmdet/core/bbox/samplers/combined_sampler.py
View file @
2de84ef8
from
.
random
_sampler
import
Random
Sampler
from
.
base
_sampler
import
Base
Sampler
from
..assign_sampling
import
build_sampler
from
..assign_sampling
import
build_sampler
class
CombinedSampler
(
Random
Sampler
):
class
CombinedSampler
(
Base
Sampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
def
__init__
(
self
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
super
(
CombinedSampler
,
self
).
__init__
(
num
,
pos_fraction
,
**
kwargs
)
super
(
CombinedSampler
,
self
).
__init__
(
**
kwargs
)
default_args
=
dict
(
num
=
num
,
pos_fraction
=
pos_fraction
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
**
kwargs
)
default_args
.
update
(
kwargs
)
self
.
neg_sampler
=
build_sampler
(
neg_sampler
,
**
kwargs
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
default_args
=
default_args
)
def
_sample_pos
(
self
,
**
kwargs
):
self
.
neg_sampler
=
build_sampler
(
raise
NotImplementedError
neg_sampler
,
default_args
=
default_args
)
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
View file @
2de84ef8
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
pos_inds
=
pos_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
View file @
2de84ef8
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self
.
hard_thr
=
hard_thr
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
neg_inds
=
neg_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/ohem_sampler.py
0 → 100644
View file @
2de84ef8
import
torch
from
.base_sampler
import
BaseSampler
from
..transforms
import
bbox2roi
class OHEMSampler(BaseSampler):
    """Sampler that prefers the candidates with the largest classification loss.

    When more positive (or negative) candidates are available than requested,
    the candidates are scored with the detector's own bbox head and only the
    highest-loss ones are kept.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        """Store sampling settings and borrow the head used for scoring.

        Args:
            num: Forwarded to the base sampler.
            pos_fraction: Forwarded to the base sampler.
            context: Object exposing ``bbox_roi_extractor`` and ``bbox_head``
                attributes; both are kept on the sampler.
            neg_pos_ub: Forwarded to the base sampler.
            add_gt_as_proposals: Forwarded to the base sampler.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        self.bbox_roi_extractor = context.bbox_roi_extractor
        self.bbox_head = context.bbox_head

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Return the ``num_expected`` entries of ``inds`` with highest loss.

        Args:
            inds: Candidate indices; indexed with the top-k positions.
            num_expected: Number of candidates to keep (passed to ``topk``).
            bboxes: Boxes of the candidates, converted to RoIs.
            labels: Labels of the candidates, used as classification targets.
            feats: Feature maps; only the first
                ``bbox_roi_extractor.num_inputs`` entries are used.
        """
        roi_extractor = self.bbox_roi_extractor
        head = self.bbox_head
        # Scoring only ranks candidates, so no gradients are needed.
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            roi_feats = roi_extractor(feats[:roi_extractor.num_inputs], rois)
            cls_score, _ = head(roi_feats)
            # Only the classification branch is evaluated; every candidate
            # gets weight one and the loss is left unreduced (per sample).
            loss_dict = head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduce=False)
            per_sample_loss = loss_dict['loss_cls']
            _, hardest = per_sample_loss.topk(num_expected)
        return inds[hardest]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard positives: entries whose ``gt_inds`` is greater than 0."""
        candidates = torch.nonzero(assign_result.gt_inds > 0)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Mining is only needed when there are more candidates than requested.
        if candidates.numel() > num_expected:
            return self.hard_mining(candidates, num_expected,
                                    bboxes[candidates],
                                    assign_result.labels[candidates], feats)
        return candidates

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard negatives: entries whose ``gt_inds`` equals 0."""
        candidates = torch.nonzero(assign_result.gt_inds == 0)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Mining is only needed when there are more candidates than requested.
        if len(candidates) > num_expected:
            return self.hard_mining(candidates, num_expected,
                                    bboxes[candidates],
                                    assign_result.labels[candidates], feats)
        return candidates
mmdet/core/bbox/samplers/pseudo_sampler.py
View file @
2de84ef8
...
@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
...
@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class
PseudoSampler
(
BaseSampler
):
class
PseudoSampler
(
BaseSampler
):
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
pass
pass
def
_sample_pos
(
self
):
def
_sample_pos
(
self
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
_sample_neg
(
self
):
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
neg_inds
=
torch
.
nonzero
(
...
...
mmdet/core/bbox/samplers/random_sampler.py
View file @
2de84ef8
...
@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
...
@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
num
,
num
,
pos_fraction
,
pos_fraction
,
neg_pos_ub
=-
1
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
):
add_gt_as_proposals
=
True
,
super
(
RandomSampler
,
self
).
__init__
()
**
kwargs
):
self
.
num
=
num
super
(
RandomSampler
,
self
).
__init__
(
num
,
pos_fraction
,
neg_pos_ub
,
self
.
pos_fraction
=
pos_fraction
add_gt_as_proposals
)
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
@
staticmethod
@
staticmethod
def
random_choice
(
gallery
,
num
):
def
random_choice
(
gallery
,
num
):
...
@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
...
@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some positive samples."""
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
...
@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
...
@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
else
:
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some negative samples."""
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
...
...
mmdet/core/loss/losses.py
View file @
2de84ef8
...
@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
...
@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
,
reduce
=
True
):
if
avg_factor
is
None
:
if
avg_factor
is
None
:
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
if
reduce
:
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
else
:
return
raw
*
weight
/
avg_factor
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
...
...
mmdet/models/backbones/__init__.py
View file @
2de84ef8
from
.resnet
import
ResNet
from
.resnet
import
ResNet
from
.resnext
import
ResNeXt
from
.ssd_vgg
import
SSDVGG
from
.ssd_vgg
import
SSDVGG
__all__
=
[
'ResNet'
,
'SSDVGG'
]
__all__
=
[
'ResNet'
,
'ResNeXt'
,
'SSDVGG'
]
mmdet/models/backbones/resnet.py
View file @
2de84ef8
...
@@ -42,7 +42,7 @@ class BasicBlock(nn.Module):
...
@@ -42,7 +42,7 @@ class BasicBlock(nn.Module):
assert
not
with_cp
assert
not
with_cp
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
residual
=
x
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
bn1
(
out
)
...
@@ -52,9 +52,9 @@ class BasicBlock(nn.Module):
...
@@ -52,9 +52,9 @@ class BasicBlock(nn.Module):
out
=
self
.
bn2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
identity
=
self
.
downsample
(
x
)
out
+=
residual
out
+=
identity
out
=
self
.
relu
(
out
)
out
=
self
.
relu
(
out
)
return
out
return
out
...
@@ -71,25 +71,31 @@ class Bottleneck(nn.Module):
...
@@ -71,25 +71,31 @@ class Bottleneck(nn.Module):
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
):
"""Bottleneck block.
"""Bottleneck block
for ResNet
.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
"""
super
(
Bottleneck
,
self
).
__init__
()
super
(
Bottleneck
,
self
).
__init__
()
assert
style
in
[
'pytorch'
,
'caffe'
]
assert
style
in
[
'pytorch'
,
'caffe'
]
self
.
inplanes
=
inplanes
self
.
planes
=
planes
if
style
==
'pytorch'
:
if
style
==
'pytorch'
:
conv1_stride
=
1
self
.
conv1_stride
=
1
conv2_stride
=
stride
self
.
conv2_stride
=
stride
else
:
else
:
conv1_stride
=
stride
self
.
conv1_stride
=
stride
conv2_stride
=
1
self
.
conv2_stride
=
1
self
.
conv1
=
nn
.
Conv2d
(
self
.
conv1
=
nn
.
Conv2d
(
inplanes
,
planes
,
kernel_size
=
1
,
stride
=
conv1_stride
,
bias
=
False
)
inplanes
,
planes
,
kernel_size
=
1
,
stride
=
self
.
conv1_stride
,
bias
=
False
)
self
.
conv2
=
nn
.
Conv2d
(
self
.
conv2
=
nn
.
Conv2d
(
planes
,
planes
,
planes
,
planes
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
conv2_stride
,
stride
=
self
.
conv2_stride
,
padding
=
dilation
,
padding
=
dilation
,
dilation
=
dilation
,
dilation
=
dilation
,
bias
=
False
)
bias
=
False
)
...
@@ -108,7 +114,7 @@ class Bottleneck(nn.Module):
...
@@ -108,7 +114,7 @@ class Bottleneck(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
def
_inner_forward
(
x
):
def
_inner_forward
(
x
):
residual
=
x
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
bn1
(
out
)
...
@@ -122,9 +128,9 @@ class Bottleneck(nn.Module):
...
@@ -122,9 +128,9 @@ class Bottleneck(nn.Module):
out
=
self
.
bn3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
identity
=
self
.
downsample
(
x
)
out
+=
residual
out
+=
identity
return
out
return
out
...
@@ -219,20 +225,24 @@ class ResNet(nn.Module):
...
@@ -219,20 +225,24 @@ class ResNet(nn.Module):
super
(
ResNet
,
self
).
__init__
()
super
(
ResNet
,
self
).
__init__
()
if
depth
not
in
self
.
arch_settings
:
if
depth
not
in
self
.
arch_settings
:
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
self
.
depth
=
depth
self
.
num_stages
=
num_stages
assert
num_stages
>=
1
and
num_stages
<=
4
assert
num_stages
>=
1
and
num_stages
<=
4
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
self
.
strides
=
strides
s
tage_blocks
=
stage_blocks
[:
num_stages
]
s
elf
.
dilations
=
dilations
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
assert
max
(
out_indices
)
<
num_stages
self
.
style
=
style
self
.
style
=
style
self
.
frozen_stages
=
frozen_stages
self
.
frozen_stages
=
frozen_stages
self
.
bn_eval
=
bn_eval
self
.
bn_eval
=
bn_eval
self
.
bn_frozen
=
bn_frozen
self
.
bn_frozen
=
bn_frozen
self
.
with_cp
=
with_cp
self
.
with_cp
=
with_cp
self
.
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
self
.
stage_blocks
=
stage_blocks
[:
num_stages
]
self
.
inplanes
=
64
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
...
@@ -240,12 +250,12 @@ class ResNet(nn.Module):
...
@@ -240,12 +250,12 @@ class ResNet(nn.Module):
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
res_layers
=
[]
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
stage_blocks
):
for
i
,
num_blocks
in
enumerate
(
self
.
stage_blocks
):
stride
=
strides
[
i
]
stride
=
strides
[
i
]
dilation
=
dilations
[
i
]
dilation
=
dilations
[
i
]
planes
=
64
*
2
**
i
planes
=
64
*
2
**
i
res_layer
=
make_res_layer
(
res_layer
=
make_res_layer
(
block
,
self
.
block
,
self
.
inplanes
,
self
.
inplanes
,
planes
,
planes
,
num_blocks
,
num_blocks
,
...
@@ -253,12 +263,13 @@ class ResNet(nn.Module):
...
@@ -253,12 +263,13 @@ class ResNet(nn.Module):
dilation
=
dilation
,
dilation
=
dilation
,
style
=
self
.
style
,
style
=
self
.
style
,
with_cp
=
with_cp
)
with_cp
=
with_cp
)
self
.
inplanes
=
planes
*
block
.
expansion
self
.
inplanes
=
planes
*
self
.
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
res_layers
.
append
(
layer_name
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
stage_blocks
)
-
1
)
self
.
feat_dim
=
self
.
block
.
expansion
*
64
*
2
**
(
len
(
self
.
stage_blocks
)
-
1
)
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
if
isinstance
(
pretrained
,
str
):
...
...
mmdet/models/backbones/resnext.py
0 → 100644
View file @
2de84ef8
import
math
import
torch.nn
as
nn
from
.resnet
import
ResNet
from
.resnet
import
Bottleneck
as
_Bottleneck
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt (grouped 3x3 convolution)."""

    def __init__(self, *args, groups=1, base_width=4, **kwargs):
        """Build the ResNet bottleneck, then replace its conv/BN layers.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            *args: Forwarded to the ResNet ``Bottleneck`` constructor.
            groups: Number of groups of the 3x3 convolution.
            base_width: Per-group width relative to a 64-channel stage.
            **kwargs: Forwarded to the ResNet ``Bottleneck`` constructor.
        """
        super(Bottleneck, self).__init__(*args, **kwargs)

        # With a single group the block keeps the plain ResNet width.
        width = (self.planes if groups == 1 else
                 math.floor(self.planes * (base_width / 64)) * groups)
        out_channels = self.planes * self.expansion

        # The layers below overwrite the ones created by the parent class.
        self.conv1 = nn.Conv2d(
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   groups=1,
                   base_width=4,
                   style='pytorch',
                   with_cp=False):
    """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

    Args:
        block: Block class; must expose an ``expansion`` attribute and accept
            the keyword arguments passed below.
        inplanes: Input channels of the first block.
        planes: Base channel count; the stage outputs
            ``planes * block.expansion`` channels.
        blocks: Number of blocks in the stage.
        stride: Stride of the first block (later blocks use stride 1).
        dilation: Dilation passed to every block.
        groups: Passed through to every block.
        base_width: Passed through to every block.
        style: Passed through to every block.
        with_cp: Passed through to every block.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_channels = planes * block.expansion

    # A 1x1 conv + BN projection is used on the shortcut whenever the first
    # block changes the spatial resolution or the channel count.
    downsample = None
    if stride != 1 or inplanes != out_channels:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(out_channels),
        )

    # Options shared by every block of the stage.
    common = dict(
        groups=groups, base_width=base_width, style=style, with_cp=with_cp)

    stage = [
        block(
            inplanes,
            planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample,
            **common)
    ]
    # Remaining blocks take the expanded channel count and never downsample.
    for _ in range(1, blocks):
        stage.append(
            block(out_channels, planes, stride=1, dilation=dilation, **common))

    return nn.Sequential(*stage)
class ResNeXt(ResNet):
    """ResNeXt backbone.

    The ResNet constructor runs first; the residual stages are then rebuilt
    with the grouped ``Bottleneck`` block defined in this module.

    Args:
        depth (int): Depth of the network, from {50, 101, 152} (the keys of
            ``arch_settings``).
        num_stages (int): Resnet stages, normally 4.
        groups (int): Group of resnext.
        base_width (int): Base width of resnext.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        super(ResNeXt, self).__init__(**kwargs)
        self.groups = groups
        self.base_width = base_width

        # Start over from the stem width and rebuild every residual stage,
        # replacing the layers registered by the parent constructor.
        self.inplanes = 64
        self.res_layers = []
        for stage_idx, num_blocks in enumerate(self.stage_blocks):
            # Channel count doubles with every stage.
            planes = 64 * 2**stage_idx
            stage = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=self.strides[stage_idx],
                dilation=self.dilations[stage_idx],
                groups=self.groups,
                base_width=self.base_width,
                style=self.style,
                with_cp=self.with_cp)
            self.inplanes = planes * self.block.expansion

            name = 'layer{}'.format(stage_idx + 1)
            self.add_module(name, stage)
            self.res_layers.append(name)
mmdet/models/bbox_heads/bbox_head.py
View file @
2de84ef8
...
@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
...
@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
return
cls_reg_targets
return
cls_reg_targets
def
loss
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
def
loss
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
):
bbox_weights
,
reduce
=
True
):
losses
=
dict
()
losses
=
dict
()
if
cls_score
is
not
None
:
if
cls_score
is
not
None
:
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
cls_score
,
labels
,
label_weights
)
cls_score
,
labels
,
label_weights
,
reduce
=
reduce
)
losses
[
'acc'
]
=
accuracy
(
cls_score
,
labels
)
losses
[
'acc'
]
=
accuracy
(
cls_score
,
labels
)
if
bbox_pred
is
not
None
:
if
bbox_pred
is
not
None
:
losses
[
'loss_reg'
]
=
weighted_smoothl1
(
losses
[
'loss_reg'
]
=
weighted_smoothl1
(
...
...
mmdet/models/detectors/two_stage.py
View file @
2de84ef8
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..
import
builder
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
)
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals
# assign gts and sample proposals
if
self
.
with_bbox
or
self
.
with_mask
:
if
self
.
with_bbox
or
self
.
with_mask
:
assign_results
,
sampling_results
=
multi_apply
(
bbox_assigner
=
build_assigner
(
self
.
train_cfg
.
rcnn
.
assigner
)
assign_and_sample
,
bbox_sampler
=
build_sampler
(
proposal_list
,
self
.
train_cfg
.
rcnn
.
sampler
,
context
=
self
)
gt_bboxes
,
num_imgs
=
img
.
size
(
0
)
gt_bboxes_ignore
,
sampling_results
=
[]
gt_labels
,
for
i
in
range
(
num_imgs
):
cfg
=
self
.
train_cfg
.
rcnn
)
assign_result
=
bbox_assigner
.
assign
(
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_bboxes_ignore
[
i
],
gt_labels
[
i
])
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_labels
[
i
],
feats
=
[
lvl_feat
[
i
][
None
]
for
lvl_feat
in
x
])
sampling_results
.
append
(
sampling_result
)
# bbox head forward and loss
# bbox head forward and loss
if
self
.
with_bbox
:
if
self
.
with_bbox
:
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
2de84ef8
...
@@ -97,7 +97,11 @@ class FCNMaskHead(nn.Module):
...
@@ -97,7 +97,11 @@ class FCNMaskHead(nn.Module):
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
loss
=
dict
()
loss
=
dict
()
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
if
self
.
class_agnostic
:
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
torch
.
zeros_like
(
labels
))
else
:
loss_mask
=
mask_cross_entropy
(
mask_pred
,
mask_targets
,
labels
)
loss
[
'loss_mask'
]
=
loss_mask
loss
[
'loss_mask'
]
=
loss_mask
return
loss
return
loss
...
...
tools/train.py
View file @
2de84ef8
...
@@ -14,6 +14,8 @@ def parse_args():
...
@@ -14,6 +14,8 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work_dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--work_dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--resume_from'
,
help
=
'the checkpoint file to resume from'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--validate'
,
'--validate'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -43,6 +45,8 @@ def main():
...
@@ -43,6 +45,8 @@ def main():
# update configs according to CLI args
# update configs according to CLI args
if
args
.
work_dir
is
not
None
:
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
work_dir
=
args
.
work_dir
if
args
.
resume_from
is
not
None
:
cfg
.
resume_from
=
args
.
resume_from
cfg
.
gpus
=
args
.
gpus
cfg
.
gpus
=
args
.
gpus
if
cfg
.
checkpoint_config
is
not
None
:
if
cfg
.
checkpoint_config
is
not
None
:
# save mmdet version in checkpoints as meta data
# save mmdet version in checkpoints as meta data
...
@@ -67,6 +71,7 @@ def main():
...
@@ -67,6 +71,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
get_dataset
(
cfg
.
data
.
train
)
train_dataset
=
get_dataset
(
cfg
.
data
.
train
)
train_detector
(
train_detector
(
model
,
model
,
...
...
tools/voc_eval.py
0 → 100644
View file @
2de84ef8
from
argparse
import
ArgumentParser
import
mmcv
import
numpy
as
np
from
mmdet
import
datasets
from
mmdet.core
import
eval_map
def voc_eval(result_file, dataset, iou_thr=0.5):
    """Evaluate saved detection results against a dataset's annotations.

    Args:
        result_file: Path of the detection results, loaded with ``mmcv.load``.
        dataset: Dataset object providing ``__len__``, ``get_ann_info(i)`` and
            ``CLASSES`` (and optionally ``year``).
        iou_thr: IoU threshold forwarded to ``eval_map``.

    The summary is printed by ``eval_map``; nothing is returned.
    """
    det_results = mmcv.load(result_file)
    gt_bboxes = []
    gt_labels = []
    gt_ignore = []
    for i in range(len(dataset)):
        ann = dataset.get_ann_info(i)
        bboxes = ann['bboxes']
        labels = ann['labels']
        if 'bboxes_ignore' in ann:
            # Mask aligned with the stacked boxes below: False for regular
            # boxes, True for the ignored ones appended after them.
            # The builtin ``bool`` replaces the removed ``np.bool`` alias.
            ignore = np.concatenate([
                np.zeros(bboxes.shape[0], dtype=bool),
                np.ones(ann['bboxes_ignore'].shape[0], dtype=bool)
            ])
            gt_ignore.append(ignore)
            bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
            labels = np.concatenate([labels, ann['labels_ignore']])
        gt_bboxes.append(bboxes)
        gt_labels.append(labels)
    if not gt_ignore:
        # No annotation carried ignore boxes: pass None rather than an empty
        # list (the original assigned gt_ignore to itself, a no-op).
        # NOTE(review): assumes eval_map treats None as "no ignore flags" -
        # confirm against its signature.
        gt_ignore = None
    if hasattr(dataset, 'year') and dataset.year == 2007:
        dataset_name = 'voc07'
    else:
        # Fall back to the dataset's class names.
        dataset_name = dataset.CLASSES
    eval_map(
        det_results,
        gt_bboxes,
        gt_labels,
        gt_ignore=gt_ignore,
        scale_ranges=None,
        iou_thr=iou_thr,
        dataset=dataset_name,
        print_summary=True)
def _parse_args():
    """Parse the command line: result file, config file and IoU threshold."""
    cli = ArgumentParser(description='VOC Evaluation')
    cli.add_argument('result', help='result file path')
    cli.add_argument('config', help='config file path')
    cli.add_argument(
        '--iou-thr',
        type=float,
        default=0.5,
        help='IoU threshold for evaluation')
    return cli.parse_args()


def main():
    """Build the test dataset from the config and evaluate the result file."""
    args = _parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    # The test dataset is instantiated from its config dict, with classes
    # looked up in the ``datasets`` module.
    test_dataset = mmcv.runner.obj_from_dict(cfg.data.test, datasets)
    voc_eval(args.result, test_dataset, args.iou_thr)


if __name__ == '__main__':
    main()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment