Unverified Commit e8397e43 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #10 from OceanPang/dev

Minor fix bugs 
parents 7d343fd2 d7743255
...@@ -53,8 +53,14 @@ class CocoDataset(Dataset): ...@@ -53,8 +53,14 @@ class CocoDataset(Dataset):
# color channel order and normalize configs # color channel order and normalize configs
self.img_norm_cfg = img_norm_cfg self.img_norm_cfg = img_norm_cfg
# proposals # proposals
self.proposals = mmcv.load( # TODO: revise _filter_imgs to be more flexible
proposal_file) if proposal_file is not None else None if proposal_file is not None:
self.proposals = mmcv.load(proposal_file)
ori_ids = self.coco.getImgIds()
sorted_idx = [ori_ids.index(id) for id in self.img_ids]
self.proposals = [self.proposals[idx] for idx in sorted_idx]
else:
self.proposals = None
self.num_max_proposals = num_max_proposals self.num_max_proposals = num_max_proposals
# flip ratio # flip ratio
self.flip_ratio = flip_ratio self.flip_ratio = flip_ratio
...@@ -271,7 +277,8 @@ class CocoDataset(Dataset): ...@@ -271,7 +277,8 @@ class CocoDataset(Dataset):
scale_factor=scale_factor, scale_factor=scale_factor,
flip=flip) flip=flip)
if proposal is not None: if proposal is not None:
_proposal = self.bbox_transform(proposal, scale_factor, flip) _proposal = self.bbox_transform(proposal, img_shape,
scale_factor, flip)
_proposal = to_tensor(_proposal) _proposal = to_tensor(_proposal)
else: else:
_proposal = None _proposal = None
......
from .base import BaseDetector from .base import BaseDetector
from .rpn import RPN from .rpn import RPN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN from .mask_rcnn import MaskRCNN
__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN'] __all__ = ['BaseDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN']
from .two_stage import TwoStageDetector
class FastRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
bbox_roi_extractor,
bbox_head,
train_cfg,
test_cfg,
mask_roi_extractor=None,
mask_head=None,
pretrained=None):
super(FastRCNN, self).__init__(
backbone=backbone,
neck=neck,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
pretrained=pretrained)
def forward_test(self, imgs, img_metas, proposals, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], proposals[0],
**kwargs)
else:
return self.aug_test(imgs, img_metas, proposals, **kwargs)
...@@ -135,6 +135,11 @@ class MaskTestMixin(object): ...@@ -135,6 +135,11 @@ class MaskTestMixin(object):
ori_shape = img_metas[0][0]['ori_shape'] ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head.get_seg_masks( segm_result = self.mask_head.get_seg_masks(
merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn, merged_masks,
ori_shape) det_bboxes,
det_labels,
self.test_cfg.rcnn,
ori_shape,
scale_factor=1.0,
rescale=False)
return segm_result return segm_result
...@@ -146,7 +146,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -146,7 +146,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x = self.extract_feat(img) x = self.extract_feat(img)
proposal_list = self.simple_test_rpn( proposal_list = self.simple_test_rpn(
x, img_meta, self.test_cfg.rpn) if proposals is None else proposals x, img_meta,
self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes( det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
......
...@@ -3,11 +3,11 @@ import argparse ...@@ -3,11 +3,11 @@ import argparse
import torch import torch
import mmcv import mmcv
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmcv.parallel import scatter, MMDataParallel from mmcv.parallel import scatter, collate, MMDataParallel
from mmdet import datasets from mmdet import datasets
from mmdet.core import results2json, coco_eval from mmdet.core import results2json, coco_eval
from mmdet.datasets import collate, build_dataloader from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, detectors from mmdet.models import build_detector, detectors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment