cascade_rcnn.py 23.1 KB
Newer Older
1
2
from __future__ import division

Kai Chen's avatar
Kai Chen committed
3
4
5
import torch
import torch.nn as nn

6
7
8
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
                        build_sampler, merge_aug_bboxes, merge_aug_masks,
                        multiclass_nms)
Kai Chen's avatar
Kai Chen committed
9
from .. import builder
Kai Chen's avatar
Kai Chen committed
10
from ..registry import DETECTORS
11
12
from .base import BaseDetector
from .test_mixins import RPNTestMixin
Kai Chen's avatar
Kai Chen committed
13
14


Kai Chen's avatar
Kai Chen committed
15
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
    """Cascade R-CNN detector, optionally with a cascaded mask branch.

    A backbone (plus optional neck) extracts features, an optional RPN head
    produces proposals, and ``num_stages`` bbox heads refine the proposals one
    stage after another. When ``mask_head`` is given, every stage also owns a
    mask head.

    Args:
        num_stages (int): number of cascade stages.
        backbone (dict): backbone config.
        neck (dict, optional): neck config.
        shared_head (dict, optional): config of a head applied to RoI features
            of every stage before the bbox / mask heads.
        rpn_head (dict, optional): RPN head config. Without it, proposals must
            be passed in by the caller.
        bbox_roi_extractor (dict | list[dict]): one config reused for every
            stage, or a list with exactly ``num_stages`` configs.
        bbox_head (dict | list[dict]): same convention as
            ``bbox_roi_extractor``.
        mask_roi_extractor (dict | list[dict], optional): same convention.
            When omitted, the bbox RoI extractors are reused for masks.
        mask_head (dict | list[dict], optional): same convention.
        train_cfg (optional): training config.
        test_cfg (optional): testing config.
        pretrained (optional): forwarded to ``init_weights``.
    """

    def __init__(self,
                 num_stages,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        assert bbox_roi_extractor is not None
        assert bbox_head is not None
        super(CascadeRCNN, self).__init__()

        self.num_stages = num_stages
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)

        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)

        # Kept even though asserted above: asserts are stripped under -O.
        if bbox_head is not None:
            self.bbox_roi_extractor = nn.ModuleList()
            self.bbox_head = nn.ModuleList()
            # A single config is replicated so that every stage gets its own
            # (independently parameterised) module.
            if not isinstance(bbox_roi_extractor, list):
                bbox_roi_extractor = [
                    bbox_roi_extractor for _ in range(num_stages)
                ]
            if not isinstance(bbox_head, list):
                bbox_head = [bbox_head for _ in range(num_stages)]
            assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
            for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
                self.bbox_roi_extractor.append(
                    builder.build_roi_extractor(roi_extractor))
                self.bbox_head.append(builder.build_head(head))

        if mask_head is not None:
            self.mask_head = nn.ModuleList()
            if not isinstance(mask_head, list):
                mask_head = [mask_head for _ in range(num_stages)]
            assert len(mask_head) == self.num_stages
            for head in mask_head:
                self.mask_head.append(builder.build_head(head))
            if mask_roi_extractor is not None:
                self.share_roi_extractor = False
                self.mask_roi_extractor = nn.ModuleList()
                if not isinstance(mask_roi_extractor, list):
                    mask_roi_extractor = [
                        mask_roi_extractor for _ in range(num_stages)
                    ]
                assert len(mask_roi_extractor) == self.num_stages
                for roi_extractor in mask_roi_extractor:
                    self.mask_roi_extractor.append(
                        builder.build_roi_extractor(roi_extractor))
            else:
                # No dedicated mask extractor: reuse the bbox ones (the same
                # ModuleList object, not a copy).
                self.share_roi_extractor = True
                self.mask_roi_extractor = self.bbox_roi_extractor

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    @property
    def with_rpn(self):
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    def init_weights(self, pretrained=None):
        """Initialise every sub-module; ``pretrained`` goes to the backbone
        and the shared head."""
        super(CascadeRCNN, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        if self.with_shared_head:
            self.shared_head.init_weights(pretrained=pretrained)
        for i in range(self.num_stages):
            if self.with_bbox:
                self.bbox_roi_extractor[i].init_weights()
                self.bbox_head[i].init_weights()
            if self.with_mask:
                # Shared extractors were already initialised as bbox ones.
                if not self.share_roi_extractor:
                    self.mask_roi_extractor[i].init_weights()
                self.mask_head[i].init_weights()

    def extract_feat(self, img):
        """Run the backbone and, if present, the neck on ``img``."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Forward every branch once with random proposals (e.g. for FLOPs
        counting) and return the tuple of raw head outputs."""
        outs = ()
        # backbone
        x = self.extract_feat(img)
        # rpn
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            outs = outs + (rpn_outs, )
        # Put the dummy proposals on the input's device instead of assuming
        # CUDA, so the dummy pass also works on CPU.
        proposals = torch.randn(1000, 4).to(device=img.device)
        # bbox heads
        rois = bbox2roi([proposals])
        if self.with_bbox:
            for i in range(self.num_stages):
                bbox_feats = self.bbox_roi_extractor[i](
                    x[:self.bbox_roi_extractor[i].num_inputs], rois)
                if self.with_shared_head:
                    bbox_feats = self.shared_head(bbox_feats)
                cls_score, bbox_pred = self.bbox_head[i](bbox_feats)
                outs = outs + (cls_score, bbox_pred)
        # mask heads
        if self.with_mask:
            mask_rois = rois[:100]
            for i in range(self.num_stages):
                mask_feats = self.mask_roi_extractor[i](
                    x[:self.mask_roi_extractor[i].num_inputs], mask_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
                mask_pred = self.mask_head[i](mask_feats)
                outs = outs + (mask_pred, )
        return outs

    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None):
        """Compute the training losses of all branches.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_meta (list[dict]): list of image info dicts where each dict
                has 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): each item holds the ground-truth boxes
                of one image in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box.

            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor): true segmentation masks for each box,
                used if the architecture supports a segmentation task.

            proposals: override rpn proposals with custom proposals. Use when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components. RCNN entries
                are keyed ``'s{stage}.{name}'``; entries whose name contains
                ``'loss'`` are multiplied by
                ``train_cfg.stage_loss_weights[stage]``.
        """
        x = self.extract_feat(img)

        losses = dict()

        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        for i in range(self.num_stages):
            self.current_stage = i
            rcnn_train_cfg = self.train_cfg.rcnn[i]
            lw = self.train_cfg.stage_loss_weights[i]

            # assign gts and sample proposals
            sampling_results = []
            if self.with_bbox or self.with_mask:
                bbox_assigner = build_assigner(rcnn_train_cfg.assigner)
                bbox_sampler = build_sampler(
                    rcnn_train_cfg.sampler, context=self)
                num_imgs = img.size(0)
                if gt_bboxes_ignore is None:
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]

                for j in range(num_imgs):
                    assign_result = bbox_assigner.assign(
                        proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
                        gt_labels[j])
                    sampling_result = bbox_sampler.sample(
                        assign_result,
                        proposal_list[j],
                        gt_bboxes[j],
                        gt_labels[j],
                        feats=[lvl_feat[j][None] for lvl_feat in x])
                    sampling_results.append(sampling_result)

            # bbox head forward and loss
            bbox_roi_extractor = self.bbox_roi_extractor[i]
            bbox_head = self.bbox_head[i]

            rois = bbox2roi([res.bboxes for res in sampling_results])
            bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                            rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred = bbox_head(bbox_feats)

            bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
                                                gt_labels, rcnn_train_cfg)
            loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
            for name, value in loss_bbox.items():
                losses['s{}.{}'.format(i, name)] = (
                    value * lw if 'loss' in name else value)

            # mask head forward and loss
            if self.with_mask:
                if not self.share_roi_extractor:
                    mask_roi_extractor = self.mask_roi_extractor[i]
                    pos_rois = bbox2roi(
                        [res.pos_bboxes for res in sampling_results])
                    mask_feats = mask_roi_extractor(
                        x[:mask_roi_extractor.num_inputs], pos_rois)
                    if self.with_shared_head:
                        mask_feats = self.shared_head(mask_feats)
                else:
                    # Reuse positive bbox feats: rois were built per image as
                    # [positives, negatives], so a per-image 1/0 mask in the
                    # same order selects the positive rows of bbox_feats.
                    pos_inds = []
                    device = bbox_feats.device
                    for res in sampling_results:
                        pos_inds.append(
                            torch.ones(
                                res.pos_bboxes.shape[0],
                                device=device,
                                dtype=torch.uint8))
                        pos_inds.append(
                            torch.zeros(
                                res.neg_bboxes.shape[0],
                                device=device,
                                dtype=torch.uint8))
                    pos_inds = torch.cat(pos_inds)
                    mask_feats = bbox_feats[pos_inds]
                mask_head = self.mask_head[i]
                mask_pred = mask_head(mask_feats)
                mask_targets = mask_head.get_target(sampling_results, gt_masks,
                                                    rcnn_train_cfg)
                pos_labels = torch.cat(
                    [res.pos_gt_labels for res in sampling_results])
                loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
                for name, value in loss_mask.items():
                    losses['s{}.{}'.format(i, name)] = (
                        value * lw if 'loss' in name else value)

            # refine bboxes to become the next stage's proposals
            if i < self.num_stages - 1:
                pos_is_gts = [res.pos_is_gt for res in sampling_results]
                roi_labels = bbox_targets[0]  # bbox_targets is a tuple
                with torch.no_grad():
                    proposal_list = bbox_head.refine_bboxes(
                        rois, roi_labels, bbox_pred, pos_is_gts, img_meta)

        return losses

    @staticmethod
    def _bboxes_for_mask(det_bboxes, scale_factor, rescale):
        """Return the boxes fed to the mask branch at test time.

        When ``rescale`` is False, ``det_bboxes`` is returned unchanged.
        Otherwise the first 4 columns are multiplied by ``scale_factor``,
        which is either a float (aspect ratio fixed) or a numpy array that is
        converted to a tensor on ``det_bboxes``'s device.
        """
        if not rescale:
            return det_bboxes
        if isinstance(scale_factor, float):  # aspect ratio fixed
            return det_bboxes[:, :4] * scale_factor
        return det_bboxes[:, :4] * torch.from_numpy(scale_factor).to(
            det_bboxes.device)

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Run inference on a single image.

        Args:
            img (Tensor): must be in shape (N, C, H, W)
            img_meta (list[dict]): a list with one dictionary element.
                See `mmdet/datasets/pipelines/formatting.py:Collect` for
                details of meta dicts.
            proposals: if specified overrides rpn proposals
            rescale (bool): if True returns boxes in original image space

        Returns:
            If ``test_cfg.keep_all_stages`` is False: the ensemble bbox
            result, or ``(bbox_result, segm_result)`` when the detector has a
            mask branch. Otherwise a dict keyed by ``'stage{i}'`` and
            ``'ensemble'`` whose values have that same per-entry form.
        """
        x = self.extract_feat(img)

        proposal_list = self.simple_test_rpn(
            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals

        img_shape = img_meta[0]['img_shape']
        ori_shape = img_meta[0]['ori_shape']
        scale_factor = img_meta[0]['scale_factor']

        # "ms" in variable names means multi-stage
        ms_bbox_result = {}
        ms_segm_result = {}
        ms_scores = []
        rcnn_test_cfg = self.test_cfg.rcnn

        rois = bbox2roi(proposal_list)
        for i in range(self.num_stages):
            bbox_roi_extractor = self.bbox_roi_extractor[i]
            bbox_head = self.bbox_head[i]

            bbox_feats = bbox_roi_extractor(
                x[:len(bbox_roi_extractor.featmap_strides)], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)

            cls_score, bbox_pred = bbox_head(bbox_feats)
            ms_scores.append(cls_score)

            if self.test_cfg.keep_all_stages:
                det_bboxes, det_labels = bbox_head.get_det_bboxes(
                    rois,
                    cls_score,
                    bbox_pred,
                    img_shape,
                    scale_factor,
                    rescale=rescale,
                    cfg=rcnn_test_cfg)
                bbox_result = bbox2result(det_bboxes, det_labels,
                                          bbox_head.num_classes)
                ms_bbox_result['stage{}'.format(i)] = bbox_result

                if self.with_mask:
                    mask_roi_extractor = self.mask_roi_extractor[i]
                    mask_head = self.mask_head[i]
                    if det_bboxes.shape[0] == 0:
                        mask_classes = mask_head.num_classes - 1
                        segm_result = [[] for _ in range(mask_classes)]
                    else:
                        # Same float / ndarray scale_factor handling as the
                        # ensemble branch below.
                        _bboxes = self._bboxes_for_mask(
                            det_bboxes, scale_factor, rescale)
                        mask_rois = bbox2roi([_bboxes])
                        mask_feats = mask_roi_extractor(
                            x[:len(mask_roi_extractor.featmap_strides)],
                            mask_rois)
                        if self.with_shared_head:
                            # Features only, like every other shared_head
                            # call in this class.
                            mask_feats = self.shared_head(mask_feats)
                        mask_pred = mask_head(mask_feats)
                        segm_result = mask_head.get_seg_masks(
                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
                            ori_shape, scale_factor, rescale)
                    ms_segm_result['stage{}'.format(i)] = segm_result

            if i < self.num_stages - 1:
                bbox_label = cls_score.argmax(dim=1)
                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                                  img_meta[0])

        # Ensemble: average the class scores of all stages, regress with the
        # last stage's bbox_pred on the last stage's rois.
        cls_score = sum(ms_scores) / self.num_stages
        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            cfg=rcnn_test_cfg)
        bbox_result = bbox2result(det_bboxes, det_labels,
                                  self.bbox_head[-1].num_classes)
        ms_bbox_result['ensemble'] = bbox_result

        if self.with_mask:
            if det_bboxes.shape[0] == 0:
                mask_classes = self.mask_head[-1].num_classes - 1
                segm_result = [[] for _ in range(mask_classes)]
            else:
                _bboxes = self._bboxes_for_mask(det_bboxes, scale_factor,
                                                rescale)
                mask_rois = bbox2roi([_bboxes])
                aug_masks = []
                for stage in range(self.num_stages):
                    mask_roi_extractor = self.mask_roi_extractor[stage]
                    mask_feats = mask_roi_extractor(
                        x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
                    if self.with_shared_head:
                        mask_feats = self.shared_head(mask_feats)
                    mask_pred = self.mask_head[stage](mask_feats)
                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
                merged_masks = merge_aug_masks(aug_masks,
                                               [img_meta] * self.num_stages,
                                               self.test_cfg.rcnn)
                segm_result = self.mask_head[-1].get_seg_masks(
                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
                    ori_shape, scale_factor, rescale)
            ms_segm_result['ensemble'] = segm_result

        if not self.test_cfg.keep_all_stages:
            if self.with_mask:
                results = (ms_bbox_result['ensemble'],
                           ms_segm_result['ensemble'])
            else:
                results = ms_bbox_result['ensemble']
        else:
            if self.with_mask:
                results = {
                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
                    for stage in ms_bbox_result
                }
            else:
                results = ms_bbox_result

        return results

    def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        Note: ``proposals`` and ``rescale`` are accepted for interface
        compatibility but not read; proposals always come from the RPN.
        """
        # recompute feats to save memory
        proposal_list = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

        rcnn_test_cfg = self.test_cfg.rcnn
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(self.extract_feats(imgs), img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']

            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
            # "ms" in variable names means multi-stage
            ms_scores = []

            rois = bbox2roi([proposals])
            for i in range(self.num_stages):
                bbox_roi_extractor = self.bbox_roi_extractor[i]
                bbox_head = self.bbox_head[i]

                bbox_feats = bbox_roi_extractor(
                    x[:len(bbox_roi_extractor.featmap_strides)], rois)
                if self.with_shared_head:
                    bbox_feats = self.shared_head(bbox_feats)

                cls_score, bbox_pred = bbox_head(bbox_feats)
                ms_scores.append(cls_score)

                if i < self.num_stages - 1:
                    bbox_label = cls_score.argmax(dim=1)
                    rois = bbox_head.regress_by_class(rois, bbox_label,
                                                      bbox_pred, img_meta[0])

            cls_score = sum(ms_scores) / float(len(ms_scores))
            # cfg=None: keep raw boxes/scores, NMS happens after merging.
            bboxes, scores = self.bbox_head[-1].get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)

        bbox_result = bbox2result(det_bboxes, det_labels,
                                  self.bbox_head[-1].num_classes)

        if self.with_mask:
            if det_bboxes.shape[0] == 0:
                segm_result = [[]
                               for _ in range(self.mask_head[-1].num_classes -
                                              1)]
            else:
                aug_masks = []
                aug_img_metas = []
                for x, img_meta in zip(self.extract_feats(imgs), img_metas):
                    img_shape = img_meta[0]['img_shape']
                    scale_factor = img_meta[0]['scale_factor']
                    flip = img_meta[0]['flip']
                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                           scale_factor, flip)
                    mask_rois = bbox2roi([_bboxes])
                    for i in range(self.num_stages):
                        mask_feats = self.mask_roi_extractor[i](
                            x[:len(self.mask_roi_extractor[i].featmap_strides
                                   )], mask_rois)
                        if self.with_shared_head:
                            mask_feats = self.shared_head(mask_feats)
                        mask_pred = self.mask_head[i](mask_feats)
                        aug_masks.append(mask_pred.sigmoid().cpu().numpy())
                        aug_img_metas.append(img_meta)
                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                               self.test_cfg.rcnn)

                ori_shape = img_metas[0][0]['ori_shape']
                segm_result = self.mask_head[-1].get_seg_masks(
                    merged_masks,
                    det_bboxes,
                    det_labels,
                    rcnn_test_cfg,
                    ori_shape,
                    scale_factor=1.0,
                    rescale=False)
            return bbox_result, segm_result
        else:
            return bbox_result

    def show_result(self, data, result, **kwargs):
        """Visualise ``result``; per-stage dict results are reduced to their
        ``'ensemble'`` entry before delegating to the base class."""
        if self.with_mask:
            ms_bbox_result, ms_segm_result = result
            if isinstance(ms_bbox_result, dict):
                result = (ms_bbox_result['ensemble'],
                          ms_segm_result['ensemble'])
        else:
            if isinstance(result, dict):
                result = result['ensemble']
        super(CascadeRCNN, self).show_result(data, result, **kwargs)