Unverified Commit afe5ce0a authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #1 from OceanPang/dev

faster-rcnn & mask-rcnn train and test support 
parents 0401cccd 782ba019
......@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return mask_targets
def loss(self, mask_pred, mask_targets, labels):
    """Compute the mask loss and return it wrapped in a dict.

    Args:
        mask_pred: mask predictions, forwarded to ``mask_cross_entropy``.
        mask_targets: mask targets, forwarded to ``mask_cross_entropy``.
        labels: labels, forwarded to ``mask_cross_entropy``.

    Returns:
        dict: a single entry, ``'loss_mask'``, holding the value returned
        by ``mask_cross_entropy``.
    """
    # NOTE(review): the captured text contained an early
    # ``return loss_mask`` ahead of the dict assignment, which made the
    # dict-building lines unreachable. The dict form is kept here because
    # the surrounding lines build and return ``loss``; confirm that callers
    # expect a dict rather than the bare loss value.
    loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
    return dict(loss_mask=loss_mask)
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_scale):
ori_shape):
"""Get segmentation masks from mask_pred and bboxes
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
......@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to original image size
ori_shape: original image size
Returns:
list[list]: encoded masks
"""
......@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms = [[] for _ in range(self.num_classes - 1)]
bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1
img_h = ori_scale[0]
img_w = ori_scale[1]
img_h = ori_shape[0]
img_w = ori_shape[1]
for i in range(bboxes.shape[0]):
bbox = bboxes[i, :].astype(int)
......
......@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from mmdet import ops
from mmdet.core import bbox_assign, bbox_sampling
class SingleLevelRoI(nn.Module):
......@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_crowds, gt_labels,
                     cfg):
    """Assign proposals to ground-truth boxes and sample RoIs from them.

    Args:
        proposals (Tensor): 2-D tensor of proposals; only the first four
            columns are used.
        gt_bboxes (Tensor): ground-truth boxes, forwarded to
            ``bbox_assign`` and, when ``cfg.add_gt_as_proposals`` is set,
            prepended to the proposals.
        gt_crowds: forwarded to ``bbox_assign``.
        gt_labels (Tensor): ground-truth labels, one per ground-truth box.
        cfg: sampling config. The attributes read here are
            ``pos_iou_thr``, ``neg_iou_thr``, ``crowd_thr``,
            ``add_gt_as_proposals``, ``roi_batch_size``, ``pos_fraction``,
            ``neg_pos_ub``, ``pos_balance_sampling`` and
            ``neg_balance_thr``.

    Returns:
        tuple: ``(pos_proposals, neg_proposals, pos_assigned_gt_inds,
        pos_gt_bboxes, pos_gt_labels)`` where ``pos_assigned_gt_inds`` are
        0-based indices into ``gt_bboxes``.
    """
    proposals = proposals[:, :4]
    # ``cfg.pos_iou_thr`` is deliberately passed twice, as in the original
    # call. NOTE(review): the meaning of the 7th positional argument of
    # ``bbox_assign`` is not visible here -- confirm against its signature.
    assigned_gt_inds, assigned_labels, _, max_overlaps = bbox_assign(
        proposals, gt_bboxes, gt_crowds, gt_labels, cfg.pos_iou_thr,
        cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)

    if cfg.add_gt_as_proposals:
        num_gts = len(gt_labels)
        proposals = torch.cat([gt_bboxes, proposals], dim=0)
        # Each prepended ground-truth box is assigned to itself; assigned
        # indices are 1-based (they are shifted back by one below).
        gt_assign_self = torch.arange(
            1, num_gts + 1, dtype=torch.long, device=proposals.device)
        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
        assigned_labels = torch.cat([gt_labels, assigned_labels])
        # Keep ``max_overlaps`` aligned with the extended index tensors:
        # without this it would be ``num_gts`` entries shorter than
        # ``assigned_gt_inds`` when both are handed to ``bbox_sampling``.
        # A box overlaps itself fully, hence the value 1.
        # NOTE(review): assumes ``max_overlaps`` is a 1-D tensor with one
        # entry per proposal -- confirm against ``bbox_assign``.
        max_overlaps = torch.cat(
            [max_overlaps.new_ones(num_gts), max_overlaps])

    pos_inds, neg_inds = bbox_sampling(
        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction,
        cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
        cfg.neg_balance_thr)

    pos_proposals = proposals[pos_inds]
    neg_proposals = proposals[neg_inds]
    # Convert the 1-based assignment to 0-based indices into the GT tensors.
    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
    pos_gt_labels = assigned_labels[pos_inds]

    return (pos_proposals, neg_proposals, pos_assigned_gt_inds,
            pos_gt_bboxes, pos_gt_labels)
def forward(self, feats, rois):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
......
......@@ -90,7 +90,11 @@ data = dict(
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5),
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True,
test_mode=False),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
......@@ -98,7 +102,10 @@ data = dict(
img_scale=(1333, 800),
flip_ratio=0,
img_norm_cfg=img_norm_cfg,
size_divisor=32))
size_divisor=32,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
......@@ -112,7 +119,7 @@ lr_config = dict(
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log')
......@@ -120,7 +127,8 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
device_ids = range(8)
dist_params = dict(backend='nccl', port='29500')
log_level = 'INFO'
work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
load_from = None
......
......@@ -103,7 +103,11 @@ data = dict(
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5),
flip_ratio=0.5,
with_mask=True,
with_crowd=True,
with_label=True,
test_mode=False),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
......@@ -111,7 +115,10 @@ data = dict(
img_scale=(1333, 800),
flip_ratio=0,
img_norm_cfg=img_norm_cfg,
size_divisor=32))
size_divisor=32,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
......@@ -120,12 +127,12 @@ lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.333,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
......@@ -133,7 +140,8 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
device_ids = range(8)
dist_params = dict(backend='nccl', port='29500')
log_level = 'INFO'
work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
load_from = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment