Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
45af4242
"eigen-master/Eigen/src/plugins/ArrayCwiseBinaryOps.inc" did not exist on "e7df86554156b36846008d8ddbcc4d8521a16554"
Commit
45af4242
authored
Oct 07, 2018
by
Kai Chen
Browse files
Merge branch 'dev' into single-stage
parents
e8d16bf2
5686a375
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
89 additions
and
122 deletions
+89
-122
mmdet/models/detectors/test_mixins.py
mmdet/models/detectors/test_mixins.py
+2
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+8
-8
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+22
-51
mmdet/models/necks/fpn.py
mmdet/models/necks/fpn.py
+2
-1
mmdet/models/roi_extractors/__init__.py
mmdet/models/roi_extractors/__init__.py
+2
-2
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+22
-42
mmdet/models/utils/__init__.py
mmdet/models/utils/__init__.py
+5
-2
mmdet/ops/__init__.py
mmdet/ops/__init__.py
+2
-0
mmdet/ops/nms/__init__.py
mmdet/ops/nms/__init__.py
+2
-0
mmdet/ops/roi_align/__init__.py
mmdet/ops/roi_align/__init__.py
+2
-0
mmdet/ops/roi_align/gradcheck.py
mmdet/ops/roi_align/gradcheck.py
+1
-1
mmdet/ops/roi_pool/__init__.py
mmdet/ops/roi_pool/__init__.py
+2
-0
mmdet/ops/roi_pool/gradcheck.py
mmdet/ops/roi_pool/gradcheck.py
+1
-1
setup.py
setup.py
+6
-3
tools/configs/r50_fpn_frcnn_1x.py
tools/configs/r50_fpn_frcnn_1x.py
+2
-2
tools/configs/r50_fpn_maskrcnn_1x.py
tools/configs/r50_fpn_maskrcnn_1x.py
+4
-4
tools/test.py
tools/test.py
+1
-1
tools/train.py
tools/train.py
+3
-2
No files found.
mmdet/models/detectors/test_mixins.py
View file @
45af4242
...
...
@@ -108,8 +108,8 @@ class MaskTestMixin(object):
x
[:
len
(
self
.
mask_roi_extractor
.
featmap_strides
)],
mask_rois
)
mask_pred
=
self
.
mask_head
(
mask_feats
)
segm_result
=
self
.
mask_head
.
get_seg_masks
(
mask_pred
,
det
_bboxes
,
det_labels
,
self
.
test_cfg
.
rcnn
,
ori_shap
e
)
mask_pred
,
_bboxes
,
det_labels
,
self
.
test_cfg
.
rcnn
,
ori_shape
,
scale_factor
,
rescal
e
)
return
segm_result
def
aug_test_mask
(
self
,
feats
,
img_metas
,
det_bboxes
,
det_labels
):
...
...
mmdet/models/detectors/two_stage.py
View file @
45af4242
...
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
mmdet.core
import
bbox2roi
,
bbox2result
,
split_combined_polys
,
multi_apply
from
mmdet.core
import
sample_bboxes
,
bbox2roi
,
bbox2result
,
multi_apply
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
...
@@ -97,13 +97,14 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
proposal_list
=
proposals
if
self
.
with_bbox
:
rcnn_train_cfg_list
=
[
self
.
train_cfg
.
rcnn
for
_
in
range
(
len
(
proposal_list
))
]
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
=
multi_apply
(
self
.
bbox_roi_extractor
.
sample_proposals
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
rcnn_train_cfg_list
)
sample_bboxes
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
=
self
.
train_cfg
.
rcnn
)
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
=
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
...
...
@@ -124,9 +125,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses
.
update
(
loss_bbox
)
if
self
.
with_mask
:
gt_polys
=
split_combined_polys
(
**
gt_masks
)
mask_targets
=
self
.
mask_head
.
get_mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_
polys
,
img_meta
,
pos_proposals
,
pos_assigned_gt_inds
,
gt_
masks
,
self
.
train_cfg
.
rcnn
)
pos_rois
=
bbox2roi
(
pos_proposals
)
mask_feats
=
self
.
mask_roi_extractor
(
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
45af4242
...
...
@@ -87,9 +87,9 @@ class FCNMaskHead(nn.Module):
return
mask_pred
def
get_mask_target
(
self
,
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
img_meta
,
rcnn_train_cfg
):
rcnn_train_cfg
):
mask_targets
=
mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
img_meta
,
rcnn_train_cfg
)
gt_masks
,
rcnn_train_cfg
)
return
mask_targets
def
loss
(
self
,
mask_pred
,
mask_targets
,
labels
):
...
...
@@ -99,8 +99,9 @@ class FCNMaskHead(nn.Module):
return
loss
def
get_seg_masks
(
self
,
mask_pred
,
det_bboxes
,
det_labels
,
rcnn_test_cfg
,
ori_shape
):
"""Get segmentation masks from mask_pred and bboxes
ori_shape
,
scale_factor
,
rescale
):
"""Get segmentation masks from mask_pred and bboxes.
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
For single-scale testing, mask_pred is the direct output of
...
...
@@ -111,6 +112,7 @@ class FCNMaskHead(nn.Module):
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
ori_shape: original image size
Returns:
list[list]: encoded masks
"""
...
...
@@ -119,65 +121,34 @@ class FCNMaskHead(nn.Module):
assert
isinstance
(
mask_pred
,
np
.
ndarray
)
cls_segms
=
[[]
for
_
in
range
(
self
.
num_classes
-
1
)]
mask_size
=
mask_pred
.
shape
[
-
1
]
bboxes
=
det_bboxes
.
cpu
().
numpy
()[:,
:
4
]
labels
=
det_labels
.
cpu
().
numpy
()
+
1
img_h
=
ori_shape
[
0
]
img_w
=
ori_shape
[
1
]
scale
=
(
mask_size
+
2.0
)
/
mask_size
bboxes
=
np
.
round
(
self
.
_bbox_scaling
(
bboxes
,
scale
)).
astype
(
np
.
int32
)
padded_mask
=
np
.
zeros
(
(
mask_size
+
2
,
mask_size
+
2
),
dtype
=
np
.
float32
)
if
rescale
:
img_h
,
img_w
=
ori_shape
[:
2
]
else
:
img_h
=
np
.
round
(
ori_shape
[
0
]
*
scale_factor
).
astype
(
np
.
int32
)
img_w
=
np
.
round
(
ori_shape
[
1
]
*
scale_factor
).
astype
(
np
.
int32
)
scale_factor
=
1.0
for
i
in
range
(
bboxes
.
shape
[
0
]):
bbox
=
bboxes
[
i
,
:].
astype
(
int
)
bbox
=
(
bboxes
[
i
,
:]
/
scale_factor
)
.
astype
(
np
.
int
32
)
label
=
labels
[
i
]
w
=
bbox
[
2
]
-
bbox
[
0
]
+
1
h
=
bbox
[
3
]
-
bbox
[
1
]
+
1
w
=
max
(
w
,
1
)
h
=
max
(
h
,
1
)
w
=
max
(
bbox
[
2
]
-
bbox
[
0
]
+
1
,
1
)
h
=
max
(
bbox
[
3
]
-
bbox
[
1
]
+
1
,
1
)
if
not
self
.
class_agnostic
:
padded_mask
[
1
:
-
1
,
1
:
-
1
]
=
mask_pred
[
i
,
label
,
:,
:]
mask_pred_
=
mask_pred
[
i
,
label
,
:,
:]
else
:
padded_mask
[
1
:
-
1
,
1
:
-
1
]
=
mask_pred
[
i
,
0
,
:,
:]
mask
=
mmcv
.
imresize
(
padded_mask
,
(
w
,
h
))
mask
=
np
.
array
(
mask
>
rcnn_test_cfg
.
mask_thr_binary
,
dtype
=
np
.
uint8
)
mask_pred_
=
mask_pred
[
i
,
0
,
:,
:]
im_mask
=
np
.
zeros
((
img_h
,
img_w
),
dtype
=
np
.
uint8
)
x0
=
max
(
bbox
[
0
],
0
)
x1
=
min
(
bbox
[
2
]
+
1
,
img_w
)
y0
=
max
(
bbox
[
1
],
0
)
y1
=
min
(
bbox
[
3
]
+
1
,
img_h
)
im_mask
[
y0
:
y1
,
x0
:
x1
]
=
mask
[(
y0
-
bbox
[
1
]):(
y1
-
bbox
[
1
]),
(
x0
-
bbox
[
0
]):(
x1
-
bbox
[
0
])]
bbox_mask
=
mmcv
.
imresize
(
mask_pred_
,
(
w
,
h
))
bbox_mask
=
(
bbox_mask
>
rcnn_test_cfg
.
mask_thr_binary
).
astype
(
np
.
uint8
)
im_mask
[
bbox
[
1
]:
bbox
[
1
]
+
h
,
bbox
[
0
]:
bbox
[
0
]
+
w
]
=
bbox_mask
rle
=
mask_util
.
encode
(
np
.
array
(
im_mask
[:,
:,
np
.
newaxis
],
order
=
'F'
))[
0
]
cls_segms
[
label
-
1
].
append
(
rle
)
return
cls_segms
def
_bbox_scaling
(
self
,
bboxes
,
scale
,
clip_shape
=
None
):
"""Scaling bboxes and clip the boundary(optional)
Args:
bboxes(ndarray): shape(..., 4)
scale(float): scaling factor
clip(None or tuple): (h, w)
Returns:
ndarray: scaled bboxes
"""
if
float
(
scale
)
==
1.0
:
scaled_bboxes
=
bboxes
.
copy
()
else
:
w
=
bboxes
[...,
2
]
-
bboxes
[...,
0
]
+
1
h
=
bboxes
[...,
3
]
-
bboxes
[...,
1
]
+
1
dw
=
(
w
*
(
scale
-
1
))
*
0.5
dh
=
(
h
*
(
scale
-
1
))
*
0.5
scaled_bboxes
=
bboxes
+
np
.
stack
((
-
dw
,
-
dh
,
dw
,
dh
),
axis
=-
1
)
if
clip_shape
is
not
None
:
return
bbox_clip
(
scaled_bboxes
,
clip_shape
)
else
:
return
scaled_bboxes
return
cls_segms
mmdet/models/necks/fpn.py
View file @
45af4242
...
...
@@ -111,7 +111,8 @@ class FPN(nn.Module):
]
# part 2: add extra levels
if
self
.
num_outs
>
len
(
outs
):
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if
not
self
.
add_extra_convs
:
for
i
in
range
(
self
.
num_outs
-
used_backbone_levels
):
outs
.
append
(
F
.
max_pool2d
(
outs
[
-
1
],
1
,
stride
=
2
))
...
...
mmdet/models/roi_extractors/__init__.py
View file @
45af4242
from
.single_level
import
Single
LevelRoI
from
.single_level
import
Single
RoIExtractor
__all__
=
[
'Single
LevelRoI
'
]
__all__
=
[
'Single
RoIExtractor
'
]
mmdet/models/roi_extractors/single_level.py
View file @
45af4242
...
...
@@ -4,19 +4,27 @@ import torch
import
torch.nn
as
nn
from
mmdet
import
ops
from
mmdet.core
import
bbox_assign
,
bbox_sampling
class
SingleLevelRoI
(
nn
.
Module
):
"""Extract RoI features from a single level feature map. Each RoI is
mapped to a level according to its scale."""
class
SingleRoIExtractor
(
nn
.
Module
):
"""Extract RoI features from a single level feature map.
If there are mulitple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
finest_scale (int): Scale threshold of mapping to level 0.
"""
def
__init__
(
self
,
roi_layer
,
out_channels
,
featmap_strides
,
finest_scale
=
56
):
super
(
Single
LevelRoI
,
self
).
__init__
()
super
(
Single
RoIExtractor
,
self
).
__init__
()
self
.
roi_layers
=
self
.
build_roi_layers
(
roi_layer
,
featmap_strides
)
self
.
out_channels
=
out_channels
self
.
featmap_strides
=
featmap_strides
...
...
@@ -24,6 +32,7 @@ class SingleLevelRoI(nn.Module):
@
property
def
num_inputs
(
self
):
"""int: Input feature map levels."""
return
len
(
self
.
featmap_strides
)
def
init_weights
(
self
):
...
...
@@ -39,12 +48,19 @@ class SingleLevelRoI(nn.Module):
return
roi_layers
def
map_roi_levels
(
self
,
rois
,
num_levels
):
"""Map rois to corresponding feature levels
(0-based)
by scales.
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale: level 0
- finest_scale <= scale < finest_scale * 2: level 1
- finest_scale * 2 <= scale < finest_scale * 4: level 2
- scale >= finest_scale * 4: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale
=
torch
.
sqrt
(
(
rois
[:,
3
]
-
rois
[:,
1
]
+
1
)
*
(
rois
[:,
4
]
-
rois
[:,
2
]
+
1
))
...
...
@@ -52,43 +68,7 @@ class SingleLevelRoI(nn.Module):
target_lvls
=
target_lvls
.
clamp
(
min
=
0
,
max
=
num_levels
-
1
).
long
()
return
target_lvls
def
sample_proposals
(
self
,
proposals
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
):
proposals
=
proposals
[:,
:
4
]
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
bbox_assign
(
proposals
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
min_pos_iou
,
cfg
.
crowd_thr
)
if
cfg
.
add_gt_as_proposals
:
proposals
=
torch
.
cat
([
gt_bboxes
,
proposals
],
dim
=
0
)
gt_assign_self
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
proposals
.
device
)
assigned_gt_inds
=
torch
.
cat
([
gt_assign_self
,
assigned_gt_inds
])
assigned_labels
=
torch
.
cat
([
gt_labels
,
assigned_labels
])
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
roi_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
pos_proposals
=
proposals
[
pos_inds
]
neg_proposals
=
proposals
[
neg_inds
]
pos_assigned_gt_inds
=
assigned_gt_inds
[
pos_inds
]
-
1
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
return
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
def
forward
(
self
,
feats
,
rois
):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
their scales.
"""
if
len
(
feats
)
==
1
:
return
self
.
roi_layers
[
0
](
feats
[
0
],
rois
)
...
...
mmdet/models/utils/__init__.py
View file @
45af4242
from
.conv_module
import
ConvModule
from
.norm
import
build_norm_layer
from
.weight_init
import
*
from
.weight_init
import
xavier_init
,
normal_init
,
uniform_init
,
kaiming_init
__all__
=
[
'ConvModule'
,
'build_norm_layer'
]
__all__
=
[
'ConvModule'
,
'build_norm_layer'
,
'xavier_init'
,
'normal_init'
,
'uniform_init'
,
'kaiming_init'
]
mmdet/ops/__init__.py
View file @
45af4242
from
.nms
import
nms
,
soft_nms
from
.roi_align
import
RoIAlign
,
roi_align
from
.roi_pool
import
RoIPool
,
roi_pool
__all__
=
[
'nms'
,
'soft_nms'
,
'RoIAlign'
,
'roi_align'
,
'RoIPool'
,
'roi_pool'
]
mmdet/ops/nms/__init__.py
View file @
45af4242
from
.nms_wrapper
import
nms
,
soft_nms
__all__
=
[
'nms'
,
'soft_nms'
]
mmdet/ops/roi_align/__init__.py
View file @
45af4242
from
.functions.roi_align
import
roi_align
from
.modules.roi_align
import
RoIAlign
__all__
=
[
'roi_align'
,
'RoIAlign'
]
mmdet/ops/roi_align/gradcheck.py
View file @
45af4242
...
...
@@ -5,7 +5,7 @@ from torch.autograd import gradcheck
import
os.path
as
osp
import
sys
sys
.
path
.
append
(
osp
.
abspath
(
osp
.
join
(
__file__
,
'../../'
)))
from
roi_align
import
RoIAlign
from
roi_align
import
RoIAlign
# noqa: E402
feat_size
=
15
spatial_scale
=
1.0
/
8
...
...
mmdet/ops/roi_pool/__init__.py
View file @
45af4242
from
.functions.roi_pool
import
roi_pool
from
.modules.roi_pool
import
RoIPool
__all__
=
[
'roi_pool'
,
'RoIPool'
]
mmdet/ops/roi_pool/gradcheck.py
View file @
45af4242
...
...
@@ -4,7 +4,7 @@ from torch.autograd import gradcheck
import
os.path
as
osp
import
sys
sys
.
path
.
append
(
osp
.
abspath
(
osp
.
join
(
__file__
,
'../../'
)))
from
roi_pool
import
RoIPool
from
roi_pool
import
RoIPool
# noqa: E402
feat
=
torch
.
randn
(
4
,
16
,
15
,
15
,
requires_grad
=
True
).
cuda
()
rois
=
torch
.
Tensor
([[
0
,
0
,
0
,
50
,
50
],
[
0
,
10
,
30
,
43
,
55
],
...
...
setup.py
View file @
45af4242
...
...
@@ -61,7 +61,7 @@ def get_hash():
def
write_version_py
():
content
=
"""# GENERATED VERSION FILE
content
=
"""# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
...
...
@@ -88,7 +88,9 @@ if __name__ == '__main__':
description
=
'Open MMLab Detection Toolbox'
,
long_description
=
readme
(),
keywords
=
'computer vision, object detection'
,
url
=
'https://github.com/open-mmlab/mmdetection'
,
packages
=
find_packages
(),
package_data
=
{
'mmdet.ops'
:
[
'*/*.so'
]},
classifiers
=
[
'Development Status :: 4 - Beta'
,
'License :: OSI Approved :: GNU General Public License v3 (GPLv3)'
,
...
...
@@ -99,10 +101,11 @@ if __name__ == '__main__':
'Programming Language :: Python :: 3.4'
,
'Programming Language :: Python :: 3.5'
,
'Programming Language :: Python :: 3.6'
,
'Topic :: Utilities'
,
],
license
=
'GPLv3'
,
setup_requires
=
[
'pytest-runner'
],
tests_require
=
[
'pytest'
],
install_requires
=
[
'numpy'
,
'matplotlib'
,
'six'
,
'terminaltables'
],
install_requires
=
[
'numpy'
,
'matplotlib'
,
'six'
,
'terminaltables'
,
'pycocotools'
],
zip_safe
=
False
)
tools/configs/r50_fpn_frcnn_1x.py
View file @
45af4242
...
...
@@ -25,7 +25,7 @@ model = dict(
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
use_sigmoid_cls
=
True
),
bbox_roi_extractor
=
dict
(
type
=
'Single
LevelRoI
'
,
type
=
'Single
RoIExtractor
'
,
roi_layer
=
dict
(
type
=
'RoIAlign'
,
out_size
=
7
,
sample_num
=
2
),
out_channels
=
256
,
featmap_strides
=
[
4
,
8
,
16
,
32
]),
...
...
@@ -131,7 +131,7 @@ lr_config = dict(
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
log_config
=
dict
(
interval
=
2
0
,
interval
=
5
0
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
...
...
tools/configs/r50_fpn_maskrcnn_1x.py
View file @
45af4242
...
...
@@ -25,7 +25,7 @@ model = dict(
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
use_sigmoid_cls
=
True
),
bbox_roi_extractor
=
dict
(
type
=
'Single
LevelRoI
'
,
type
=
'Single
RoIExtractor
'
,
roi_layer
=
dict
(
type
=
'RoIAlign'
,
out_size
=
7
,
sample_num
=
2
),
out_channels
=
256
,
featmap_strides
=
[
4
,
8
,
16
,
32
]),
...
...
@@ -40,7 +40,7 @@ model = dict(
target_stds
=
[
0.1
,
0.1
,
0.2
,
0.2
],
reg_class_agnostic
=
False
),
mask_roi_extractor
=
dict
(
type
=
'Single
LevelRoI
'
,
type
=
'Single
RoIExtractor
'
,
roi_layer
=
dict
(
type
=
'RoIAlign'
,
out_size
=
14
,
sample_num
=
2
),
out_channels
=
256
,
featmap_strides
=
[
4
,
8
,
16
,
32
]),
...
...
@@ -144,10 +144,10 @@ lr_config = dict(
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
log_config
=
dict
(
interval
=
2
0
,
interval
=
5
0
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
#
(
'TensorboardLoggerHook',
dict(
log_dir=work_dir + '/log')
),
#
dict(type=
'TensorboardLoggerHook', log_dir=work_dir + '/log')
])
# yapf:enable
# runtime settings
...
...
tools/test.py
View file @
45af4242
...
...
@@ -6,7 +6,7 @@ from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from
mmdet
import
datasets
from
mmdet.core
import
scatter
,
MMDataParallel
,
results2json
,
coco_eval
from
mmdet.datasets
.loader
import
collate
,
build_dataloader
from
mmdet.datasets
import
collate
,
build_dataloader
from
mmdet.models
import
build_detector
,
detectors
...
...
tools/train.py
View file @
45af4242
...
...
@@ -13,7 +13,7 @@ from mmdet import datasets, __version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
DistSamplerSeedHook
,
MMDataParallel
,
MMDistributedDataParallel
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
.loader
import
build_dataloader
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
build_detector
,
RPN
...
...
@@ -90,7 +90,8 @@ def main():
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
# add mmdet version to checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
)
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
logger
=
get_logger
(
cfg
.
log_level
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment