Commit 45af4242 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into single-stage

parents e8d16bf2 5686a375
...@@ -108,8 +108,8 @@ class MaskTestMixin(object): ...@@ -108,8 +108,8 @@ class MaskTestMixin(object):
x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois) x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks( segm_result = self.mask_head.get_seg_masks(
mask_pred, det_bboxes, det_labels, self.test_cfg.rcnn, mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
ori_shape) scale_factor, rescale)
return segm_result return segm_result
def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels): def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import bbox2roi, bbox2result, split_combined_polys, multi_apply from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...@@ -97,13 +97,14 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -97,13 +97,14 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
proposal_list = proposals proposal_list = proposals
if self.with_bbox: if self.with_bbox:
rcnn_train_cfg_list = [
self.train_cfg.rcnn for _ in range(len(proposal_list))
]
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
pos_gt_labels) = multi_apply( pos_gt_labels) = multi_apply(
self.bbox_roi_extractor.sample_proposals, proposal_list, sample_bboxes,
gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list) proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
(labels, label_weights, bbox_targets, (labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target( bbox_weights) = self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels, pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
...@@ -124,9 +125,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -124,9 +125,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses.update(loss_bbox) losses.update(loss_bbox)
if self.with_mask: if self.with_mask:
gt_polys = split_combined_polys(**gt_masks)
mask_targets = self.mask_head.get_mask_target( mask_targets = self.mask_head.get_mask_target(
pos_proposals, pos_assigned_gt_inds, gt_polys, img_meta, pos_proposals, pos_assigned_gt_inds, gt_masks,
self.train_cfg.rcnn) self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals) pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor( mask_feats = self.mask_roi_extractor(
......
...@@ -87,9 +87,9 @@ class FCNMaskHead(nn.Module): ...@@ -87,9 +87,9 @@ class FCNMaskHead(nn.Module):
return mask_pred return mask_pred
def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks, def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
img_meta, rcnn_train_cfg): rcnn_train_cfg):
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, img_meta, rcnn_train_cfg) gt_masks, rcnn_train_cfg)
return mask_targets return mask_targets
def loss(self, mask_pred, mask_targets, labels): def loss(self, mask_pred, mask_targets, labels):
...@@ -99,8 +99,9 @@ class FCNMaskHead(nn.Module): ...@@ -99,8 +99,9 @@ class FCNMaskHead(nn.Module):
return loss return loss
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_shape): ori_shape, scale_factor, rescale):
"""Get segmentation masks from mask_pred and bboxes """Get segmentation masks from mask_pred and bboxes.
Args: Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w). mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
For single-scale testing, mask_pred is the direct output of For single-scale testing, mask_pred is the direct output of
...@@ -111,6 +112,7 @@ class FCNMaskHead(nn.Module): ...@@ -111,6 +112,7 @@ class FCNMaskHead(nn.Module):
img_shape (Tensor): shape (3, ) img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config rcnn_test_cfg (dict): rcnn testing config
ori_shape: original image size ori_shape: original image size
Returns: Returns:
list[list]: encoded masks list[list]: encoded masks
""" """
...@@ -119,65 +121,34 @@ class FCNMaskHead(nn.Module): ...@@ -119,65 +121,34 @@ class FCNMaskHead(nn.Module):
assert isinstance(mask_pred, np.ndarray) assert isinstance(mask_pred, np.ndarray)
cls_segms = [[] for _ in range(self.num_classes - 1)] cls_segms = [[] for _ in range(self.num_classes - 1)]
mask_size = mask_pred.shape[-1]
bboxes = det_bboxes.cpu().numpy()[:, :4] bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1 labels = det_labels.cpu().numpy() + 1
img_h = ori_shape[0]
img_w = ori_shape[1]
scale = (mask_size + 2.0) / mask_size if rescale:
bboxes = np.round(self._bbox_scaling(bboxes, scale)).astype(np.int32) img_h, img_w = ori_shape[:2]
padded_mask = np.zeros( else:
(mask_size + 2, mask_size + 2), dtype=np.float32) img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
scale_factor = 1.0
for i in range(bboxes.shape[0]): for i in range(bboxes.shape[0]):
bbox = bboxes[i, :].astype(int) bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
label = labels[i] label = labels[i]
w = bbox[2] - bbox[0] + 1 w = max(bbox[2] - bbox[0] + 1, 1)
h = bbox[3] - bbox[1] + 1 h = max(bbox[3] - bbox[1] + 1, 1)
w = max(w, 1)
h = max(h, 1)
if not self.class_agnostic: if not self.class_agnostic:
padded_mask[1:-1, 1:-1] = mask_pred[i, label, :, :] mask_pred_ = mask_pred[i, label, :, :]
else: else:
padded_mask[1:-1, 1:-1] = mask_pred[i, 0, :, :] mask_pred_ = mask_pred[i, 0, :, :]
mask = mmcv.imresize(padded_mask, (w, h))
mask = np.array(
mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
im_mask = np.zeros((img_h, img_w), dtype=np.uint8) im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
x0 = max(bbox[0], 0) bbox_mask = mmcv.imresize(mask_pred_, (w, h))
x1 = min(bbox[2] + 1, img_w) bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
y0 = max(bbox[1], 0) np.uint8)
y1 = min(bbox[3] + 1, img_h) im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
im_mask[y0:y1, x0:x1] = mask[(y0 - bbox[1]):(y1 - bbox[1]), (
x0 - bbox[0]):(x1 - bbox[0])]
rle = mask_util.encode( rle = mask_util.encode(
np.array(im_mask[:, :, np.newaxis], order='F'))[0] np.array(im_mask[:, :, np.newaxis], order='F'))[0]
cls_segms[label - 1].append(rle) cls_segms[label - 1].append(rle)
return cls_segms
def _bbox_scaling(self, bboxes, scale, clip_shape=None): return cls_segms
"""Scaling bboxes and clip the boundary(optional)
Args:
bboxes(ndarray): shape(..., 4)
scale(float): scaling factor
clip(None or tuple): (h, w)
Returns:
ndarray: scaled bboxes
"""
if float(scale) == 1.0:
scaled_bboxes = bboxes.copy()
else:
w = bboxes[..., 2] - bboxes[..., 0] + 1
h = bboxes[..., 3] - bboxes[..., 1] + 1
dw = (w * (scale - 1)) * 0.5
dh = (h * (scale - 1)) * 0.5
scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
if clip_shape is not None:
return bbox_clip(scaled_bboxes, clip_shape)
else:
return scaled_bboxes
...@@ -111,7 +111,8 @@ class FPN(nn.Module): ...@@ -111,7 +111,8 @@ class FPN(nn.Module):
] ]
# part 2: add extra levels # part 2: add extra levels
if self.num_outs > len(outs): if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN) # use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs: if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels): for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2)) outs.append(F.max_pool2d(outs[-1], 1, stride=2))
......
from .single_level import SingleLevelRoI from .single_level import SingleRoIExtractor
__all__ = ['SingleLevelRoI'] __all__ = ['SingleRoIExtractor']
...@@ -4,19 +4,27 @@ import torch ...@@ -4,19 +4,27 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmdet import ops from mmdet import ops
from mmdet.core import bbox_assign, bbox_sampling
class SingleLevelRoI(nn.Module): class SingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map. Each RoI is """Extract RoI features from a single level feature map.
mapped to a level according to its scale."""
If there are mulitple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
finest_scale (int): Scale threshold of mapping to level 0.
"""
def __init__(self, def __init__(self,
roi_layer, roi_layer,
out_channels, out_channels,
featmap_strides, featmap_strides,
finest_scale=56): finest_scale=56):
super(SingleLevelRoI, self).__init__() super(SingleRoIExtractor, self).__init__()
self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
self.out_channels = out_channels self.out_channels = out_channels
self.featmap_strides = featmap_strides self.featmap_strides = featmap_strides
...@@ -24,6 +32,7 @@ class SingleLevelRoI(nn.Module): ...@@ -24,6 +32,7 @@ class SingleLevelRoI(nn.Module):
@property @property
def num_inputs(self): def num_inputs(self):
"""int: Input feature map levels."""
return len(self.featmap_strides) return len(self.featmap_strides)
def init_weights(self): def init_weights(self):
...@@ -39,12 +48,19 @@ class SingleLevelRoI(nn.Module): ...@@ -39,12 +48,19 @@ class SingleLevelRoI(nn.Module):
return roi_layers return roi_layers
def map_roi_levels(self, rois, num_levels): def map_roi_levels(self, rois, num_levels):
"""Map rois to corresponding feature levels (0-based) by scales. """Map rois to corresponding feature levels by scales.
- scale < finest_scale: level 0 - scale < finest_scale: level 0
- finest_scale <= scale < finest_scale * 2: level 1 - finest_scale <= scale < finest_scale * 2: level 1
- finest_scale * 2 <= scale < finest_scale * 4: level 2 - finest_scale * 2 <= scale < finest_scale * 4: level 2
- scale >= finest_scale * 4: level 3 - scale >= finest_scale * 4: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
""" """
scale = torch.sqrt( scale = torch.sqrt(
(rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1)) (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
...@@ -52,43 +68,7 @@ class SingleLevelRoI(nn.Module): ...@@ -52,43 +68,7 @@ class SingleLevelRoI(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls return target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_bboxes_ignore,
gt_labels, cfg):
proposals = proposals[:, :4]
assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
bbox_assign(proposals, gt_bboxes, gt_bboxes_ignore, gt_labels,
cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
cfg.crowd_thr)
if cfg.add_gt_as_proposals:
proposals = torch.cat([gt_bboxes, proposals], dim=0)
gt_assign_self = torch.arange(
1,
len(gt_labels) + 1,
dtype=torch.long,
device=proposals.device)
assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
assigned_labels = torch.cat([gt_labels, assigned_labels])
pos_inds, neg_inds = bbox_sampling(
assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction,
cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
pos_proposals = proposals[pos_inds]
neg_proposals = proposals[neg_inds]
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
pos_gt_labels = assigned_labels[pos_inds]
return (pos_proposals, neg_proposals, pos_assigned_gt_inds,
pos_gt_bboxes, pos_gt_labels)
def forward(self, feats, rois): def forward(self, feats, rois):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
their scales.
"""
if len(feats) == 1: if len(feats) == 1:
return self.roi_layers[0](feats[0], rois) return self.roi_layers[0](feats[0], rois)
......
from .conv_module import ConvModule from .conv_module import ConvModule
from .norm import build_norm_layer from .norm import build_norm_layer
from .weight_init import * from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
__all__ = ['ConvModule', 'build_norm_layer'] __all__ = [
'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
'uniform_init', 'kaiming_init'
]
from .nms import nms, soft_nms from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool from .roi_pool import RoIPool, roi_pool
__all__ = ['nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool']
from .nms_wrapper import nms, soft_nms from .nms_wrapper import nms, soft_nms
__all__ = ['nms', 'soft_nms']
from .functions.roi_align import roi_align from .functions.roi_align import roi_align
from .modules.roi_align import RoIAlign from .modules.roi_align import RoIAlign
__all__ = ['roi_align', 'RoIAlign']
...@@ -5,7 +5,7 @@ from torch.autograd import gradcheck ...@@ -5,7 +5,7 @@ from torch.autograd import gradcheck
import os.path as osp import os.path as osp
import sys import sys
sys.path.append(osp.abspath(osp.join(__file__, '../../'))) sys.path.append(osp.abspath(osp.join(__file__, '../../')))
from roi_align import RoIAlign from roi_align import RoIAlign # noqa: E402
feat_size = 15 feat_size = 15
spatial_scale = 1.0 / 8 spatial_scale = 1.0 / 8
......
from .functions.roi_pool import roi_pool from .functions.roi_pool import roi_pool
from .modules.roi_pool import RoIPool from .modules.roi_pool import RoIPool
__all__ = ['roi_pool', 'RoIPool']
...@@ -4,7 +4,7 @@ from torch.autograd import gradcheck ...@@ -4,7 +4,7 @@ from torch.autograd import gradcheck
import os.path as osp import os.path as osp
import sys import sys
sys.path.append(osp.abspath(osp.join(__file__, '../../'))) sys.path.append(osp.abspath(osp.join(__file__, '../../')))
from roi_pool import RoIPool from roi_pool import RoIPool # noqa: E402
feat = torch.randn(4, 16, 15, 15, requires_grad=True).cuda() feat = torch.randn(4, 16, 15, 15, requires_grad=True).cuda()
rois = torch.Tensor([[0, 0, 0, 50, 50], [0, 10, 30, 43, 55], rois = torch.Tensor([[0, 0, 0, 50, 50], [0, 10, 30, 43, 55],
......
...@@ -61,7 +61,7 @@ def get_hash(): ...@@ -61,7 +61,7 @@ def get_hash():
def write_version_py(): def write_version_py():
content = """# GENERATED VERSION FILE content = """# GENERATED VERSION FILE
# TIME: {} # TIME: {}
__version__ = '{}' __version__ = '{}'
...@@ -88,7 +88,9 @@ if __name__ == '__main__': ...@@ -88,7 +88,9 @@ if __name__ == '__main__':
description='Open MMLab Detection Toolbox', description='Open MMLab Detection Toolbox',
long_description=readme(), long_description=readme(),
keywords='computer vision, object detection', keywords='computer vision, object detection',
url='https://github.com/open-mmlab/mmdetection',
packages=find_packages(), packages=find_packages(),
package_data={'mmdet.ops': ['*/*.so']},
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
...@@ -99,10 +101,11 @@ if __name__ == '__main__': ...@@ -99,10 +101,11 @@ if __name__ == '__main__':
'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Topic :: Utilities',
], ],
license='GPLv3', license='GPLv3',
setup_requires=['pytest-runner'], setup_requires=['pytest-runner'],
tests_require=['pytest'], tests_require=['pytest'],
install_requires=['numpy', 'matplotlib', 'six', 'terminaltables'], install_requires=[
'numpy', 'matplotlib', 'six', 'terminaltables', 'pycocotools'
],
zip_safe=False) zip_safe=False)
...@@ -25,7 +25,7 @@ model = dict( ...@@ -25,7 +25,7 @@ model = dict(
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True), use_sigmoid_cls=True),
bbox_roi_extractor=dict( bbox_roi_extractor=dict(
type='SingleLevelRoI', type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256, out_channels=256,
featmap_strides=[4, 8, 16, 32]), featmap_strides=[4, 8, 16, 32]),
...@@ -131,7 +131,7 @@ lr_config = dict( ...@@ -131,7 +131,7 @@ lr_config = dict(
checkpoint_config = dict(interval=1) checkpoint_config = dict(interval=1)
# yapf:disable # yapf:disable
log_config = dict( log_config = dict(
interval=20, interval=50,
hooks=[ hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log') # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
......
...@@ -25,7 +25,7 @@ model = dict( ...@@ -25,7 +25,7 @@ model = dict(
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True), use_sigmoid_cls=True),
bbox_roi_extractor=dict( bbox_roi_extractor=dict(
type='SingleLevelRoI', type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256, out_channels=256,
featmap_strides=[4, 8, 16, 32]), featmap_strides=[4, 8, 16, 32]),
...@@ -40,7 +40,7 @@ model = dict( ...@@ -40,7 +40,7 @@ model = dict(
target_stds=[0.1, 0.1, 0.2, 0.2], target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False), reg_class_agnostic=False),
mask_roi_extractor=dict( mask_roi_extractor=dict(
type='SingleLevelRoI', type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=256, out_channels=256,
featmap_strides=[4, 8, 16, 32]), featmap_strides=[4, 8, 16, 32]),
...@@ -144,10 +144,10 @@ lr_config = dict( ...@@ -144,10 +144,10 @@ lr_config = dict(
checkpoint_config = dict(interval=1) checkpoint_config = dict(interval=1)
# yapf:disable # yapf:disable
log_config = dict( log_config = dict(
interval=20, interval=50,
hooks=[ hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')), # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
]) ])
# yapf:enable # yapf:enable
# runtime settings # runtime settings
......
...@@ -6,7 +6,7 @@ from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict ...@@ -6,7 +6,7 @@ from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmdet import datasets from mmdet import datasets
from mmdet.core import scatter, MMDataParallel, results2json, coco_eval from mmdet.core import scatter, MMDataParallel, results2json, coco_eval
from mmdet.datasets.loader import collate, build_dataloader from mmdet.datasets import collate, build_dataloader
from mmdet.models import build_detector, detectors from mmdet.models import build_detector, detectors
......
...@@ -13,7 +13,7 @@ from mmdet import datasets, __version__ ...@@ -13,7 +13,7 @@ from mmdet import datasets, __version__
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook, from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
MMDataParallel, MMDistributedDataParallel, MMDataParallel, MMDistributedDataParallel,
CocoDistEvalRecallHook, CocoDistEvalmAPHook) CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets.loader import build_dataloader from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, RPN from mmdet.models import build_detector, RPN
...@@ -90,7 +90,8 @@ def main(): ...@@ -90,7 +90,8 @@ def main():
cfg.work_dir = args.work_dir cfg.work_dir = args.work_dir
cfg.gpus = args.gpus cfg.gpus = args.gpus
# add mmdet version to checkpoint as meta data # add mmdet version to checkpoint as meta data
cfg.checkpoint_config.meta = dict(mmdet_version=__version__) cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text)
logger = get_logger(cfg.log_level) logger = get_logger(cfg.log_level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment