Commit afe5ce0a (unverified), authored by Kai Chen, committed by GitHub
Browse files

Merge pull request #1 from OceanPang/dev

faster-rcnn & mask-rcnn train and test support 
parents 0401cccd 782ba019
...@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module): ...@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return mask_targets return mask_targets
def loss(self, mask_pred, mask_targets, labels): def loss(self, mask_pred, mask_targets, labels):
loss = dict()
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels) loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
return loss_mask loss['loss_mask'] = loss_mask
return loss
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_scale): ori_shape):
"""Get segmentation masks from mask_pred and bboxes """Get segmentation masks from mask_pred and bboxes
Args: Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w). mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
...@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module): ...@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, ) det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, ) img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to original image size ori_shape: original image size
Returns: Returns:
list[list]: encoded masks list[list]: encoded masks
""" """
...@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module): ...@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms = [[] for _ in range(self.num_classes - 1)] cls_segms = [[] for _ in range(self.num_classes - 1)]
bboxes = det_bboxes.cpu().numpy()[:, :4] bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1 labels = det_labels.cpu().numpy() + 1
img_h = ori_scale[0] img_h = ori_shape[0]
img_w = ori_scale[1] img_w = ori_shape[1]
for i in range(bboxes.shape[0]): for i in range(bboxes.shape[0]):
bbox = bboxes[i, :].astype(int) bbox = bboxes[i, :].astype(int)
......
...@@ -4,6 +4,7 @@ import torch ...@@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmdet import ops from mmdet import ops
from mmdet.core import bbox_assign, bbox_sampling
class SingleLevelRoI(nn.Module): class SingleLevelRoI(nn.Module):
...@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module): ...@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls return target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_crowds, gt_labels,
                     cfg):
    """Assign proposals to ground-truth boxes, then sample pos/neg RoIs.

    Only the first four columns of ``proposals`` are used. The assignment
    is delegated to ``bbox_assign`` and the sampling to ``bbox_sampling``;
    both are driven by thresholds and ratios read from ``cfg``.

    Args:
        proposals: 2-D indexable tensor; sliced as ``proposals[:, :4]``.
        gt_bboxes: ground-truth boxes, concatenated with the proposals
            along dim 0 when ``cfg.add_gt_as_proposals`` is truthy.
        gt_crowds: forwarded unchanged to ``bbox_assign``.
        gt_labels: ground-truth labels, one per ground-truth box.
        cfg: config object providing ``pos_iou_thr``, ``neg_iou_thr``,
            ``crowd_thr``, ``add_gt_as_proposals``, ``roi_batch_size``,
            ``pos_fraction``, ``neg_pos_ub``, ``pos_balance_sampling``
            and ``neg_balance_thr``.

    Returns:
        tuple: ``(pos_proposals, neg_proposals, pos_assigned_gt_inds,
        pos_gt_bboxes, pos_gt_labels)`` where ``pos_assigned_gt_inds`` has
        been shifted down by one so it can index ``gt_bboxes`` directly.
    """
    proposals = proposals[:, :4]

    # cfg.pos_iou_thr is intentionally passed twice, as in the original call.
    # NOTE(review): the second occurrence presumably fills a separate
    # "minimum positive IoU" slot of bbox_assign -- confirm its signature.
    assign_result = bbox_assign(
        proposals, gt_bboxes, gt_crowds, gt_labels,
        cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
    gt_inds, labels, argmax_overlaps, max_overlaps = assign_result

    if cfg.add_gt_as_proposals:
        # Prepend the ground-truth boxes and give each one its own 1-based
        # index, so every GT box is assigned to itself.
        proposals = torch.cat([gt_bboxes, proposals], dim=0)
        self_inds = torch.arange(
            1, len(gt_labels) + 1,
            dtype=torch.long, device=proposals.device)
        gt_inds = torch.cat([self_inds, gt_inds])
        labels = torch.cat([gt_labels, labels])
        # NOTE(review): max_overlaps is not extended for the prepended GT
        # rows -- confirm bbox_sampling tolerates the length difference.

    pos_inds, neg_inds = bbox_sampling(
        gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)

    # Sampled indices are subtracted by one to become 0-based rows of
    # gt_bboxes.
    pos_assigned_gt_inds = gt_inds[pos_inds] - 1
    return (
        proposals[pos_inds],
        proposals[neg_inds],
        pos_assigned_gt_inds,
        gt_bboxes[pos_assigned_gt_inds, :],
        labels[pos_inds],
    )
def forward(self, feats, rois): def forward(self, feats, rois):
"""Extract roi features with the roi layer. If multiple feature levels """Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to are used, then rois are mapped to corresponding levels according to
......
...@@ -90,7 +90,11 @@ data = dict( ...@@ -90,7 +90,11 @@ data = dict(
img_scale=(1333, 800), img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg, img_norm_cfg=img_norm_cfg,
size_divisor=32, size_divisor=32,
flip_ratio=0.5), flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True,
test_mode=False),
test=dict( test=dict(
type=dataset_type, type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json', ann_file=data_root + 'annotations/instances_val2017.json',
...@@ -98,7 +102,10 @@ data = dict( ...@@ -98,7 +102,10 @@ data = dict(
img_scale=(1333, 800), img_scale=(1333, 800),
flip_ratio=0, flip_ratio=0,
img_norm_cfg=img_norm_cfg, img_norm_cfg=img_norm_cfg,
size_divisor=32)) size_divisor=32,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer # optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
...@@ -112,7 +119,7 @@ lr_config = dict( ...@@ -112,7 +119,7 @@ lr_config = dict(
checkpoint_config = dict(interval=1) checkpoint_config = dict(interval=1)
# yapf:disable # yapf:disable
log_config = dict( log_config = dict(
interval=50, interval=20,
hooks=[ hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log') # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
...@@ -120,7 +127,8 @@ log_config = dict( ...@@ -120,7 +127,8 @@ log_config = dict(
# yapf:enable # yapf:enable
# runtime settings # runtime settings
total_epochs = 12 total_epochs = 12
dist_params = dict(backend='nccl') device_ids = range(8)
dist_params = dict(backend='nccl', port='29500')
log_level = 'INFO' log_level = 'INFO'
work_dir = './work_dirs/fpn_faster_rcnn_r50_1x' work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
load_from = None load_from = None
......
...@@ -103,7 +103,11 @@ data = dict( ...@@ -103,7 +103,11 @@ data = dict(
img_scale=(1333, 800), img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg, img_norm_cfg=img_norm_cfg,
size_divisor=32, size_divisor=32,
flip_ratio=0.5), flip_ratio=0.5,
with_mask=True,
with_crowd=True,
with_label=True,
test_mode=False),
test=dict( test=dict(
type=dataset_type, type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json', ann_file=data_root + 'annotations/instances_val2017.json',
...@@ -111,7 +115,10 @@ data = dict( ...@@ -111,7 +115,10 @@ data = dict(
img_scale=(1333, 800), img_scale=(1333, 800),
flip_ratio=0, flip_ratio=0,
img_norm_cfg=img_norm_cfg, img_norm_cfg=img_norm_cfg,
size_divisor=32)) size_divisor=32,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer # optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
...@@ -120,12 +127,12 @@ lr_config = dict( ...@@ -120,12 +127,12 @@ lr_config = dict(
policy='step', policy='step',
warmup='linear', warmup='linear',
warmup_iters=500, warmup_iters=500,
warmup_ratio=0.333, warmup_ratio=1.0 / 3,
step=[8, 11]) step=[8, 11])
checkpoint_config = dict(interval=1) checkpoint_config = dict(interval=1)
# yapf:disable # yapf:disable
log_config = dict( log_config = dict(
interval=50, interval=20,
hooks=[ hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')), # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
...@@ -133,7 +140,8 @@ log_config = dict( ...@@ -133,7 +140,8 @@ log_config = dict(
# yapf:enable # yapf:enable
# runtime settings # runtime settings
total_epochs = 12 total_epochs = 12
dist_params = dict(backend='nccl') device_ids = range(8)
dist_params = dict(backend='nccl', port='29500')
log_level = 'INFO' log_level = 'INFO'
work_dir = './work_dirs/fpn_mask_rcnn_r50_1x' work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
load_from = None load_from = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment