Commit 6ce7cce1 authored by Tao Kong's avatar Tao Kong Committed by Kai Chen
Browse files

FoveaBox support (#1339)

* add foveabox support

* format

* fix

* foveabox correct

* fix mean and std in config

* add foveabox in model zoo

* fix work_dir, typo, fovea_target and config issues

* add necessary comments

* fix yapf issue in C1 check

* fix yapf issue in C1 check

* rename configs

* update model urls and minor fix
parent 0e9de200
# FoveaBox: Beyond Anchor-based Object Detector
FoveaBox is an accurate, flexible and completely anchor-free object detection system for object detection framework, as presented in our paper [https://arxiv.org/abs/1904.03797](https://arxiv.org/abs/1904.03797):
Different from previous anchor-based methods, FoveaBox directly learns the object existing possibility and the bounding box coordinates without anchor reference. This is achieved by: (a) predicting category-sensitive semantic maps for the object existing possibility, and (b) producing category-agnostic bounding box for each position that potentially contains an object.
## Main Results
### Results on R50/101-FPN
| Backbone | Style | align | ms-train| Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:---------:|:-------:|:-------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | N | N | 1x | 5.7 | - | | 36.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_r50_fpn_4gpu_1x_20190905-3b185a5d.pth) |
| R-50 | pytorch | N | N | 2x | - | - | | 36.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_r50_fpn_4gpu_2x_20190905-4a07f6e0.pth) |
| R-50 | pytorch | Y | N | 2x | - | - | | 37.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_align_gn_r50_fpn_4gpu_2x_20190905-3e6bc82f.pth) |
| R-50 | pytorch | Y | Y | 2x | - | - | | 40.1 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_align_gn_ms_r50_fpn_4gpu_2x_20190905-13374f33.pth) |
| R-101 | pytorch | N | N | 1x | 9.4 | - | | 38.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_r101_fpn_4gpu_1x_20190905-80ff93a6.pth) |
| R-101 | pytorch | N | N | 2x | - | - | | 38.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_r101_fpn_4gpu_2x_20190905-d9c99fb1.pth) |
| R-101 | pytorch | Y | N | 2x | - | - | | 39.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_align_gn_r101_fpn_4gpu_2x_20190905-407ddad6.pth) |
| R-101 | pytorch | Y | Y | 2x | - | - | | 41.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/foveabox/fovea_align_gn_ms_r101_fpn_4gpu_2x_20190905-936c7277.pth) |
[1] *1x and 2x mean the model is trained for 12 and 24 epochs, respectively.* \
[2] *Align means utilizing deformable convolution to align the cls branch.* \
[3] *All results are obtained with a single model and without any test time data augmentation.*\
[4] *We use 4 NVIDIA Tesla V100 GPUs for training.*
Any pull requests or issues are welcome.
## Citations
Please consider citing our paper in your publications if the project helps your research. BibTeX reference is as follows.
```
@article{kong2019foveabox,
title={FoveaBox: Beyond Anchor-based Object Detector},
author={Kong, Tao and Sun, Fuchun and Liu, Huaping and Jiang, Yuning and Shi, Jianbo},
journal={arXiv preprint arXiv:1904.03797},
year={2019}
}
```
\ No newline at end of file
# model settings
model = dict(
type='FOVEA',
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
num_outs=5,
add_extra_convs=True),
bbox_head=dict(
type='FoveaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
base_edge_list=[16, 32, 64, 128, 256],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
sigma=0.4,
with_deform=True,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.50,
alpha=0.4,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=1000,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 24
device_ids = range(4)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fovea_align_gn_ms_r101_fpn_4gpu_2x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FOVEA',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
num_outs=5,
add_extra_convs=True),
bbox_head=dict(
type='FoveaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
base_edge_list=[16, 32, 64, 128, 256],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
sigma=0.4,
with_deform=True,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.50,
alpha=0.4,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=1000,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 24
device_ids = range(4)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fovea_align_gn_ms_r50_fpn_4gpu_2x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FOVEA',
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
num_outs=5,
add_extra_convs=True),
bbox_head=dict(
type='FoveaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
base_edge_list=[16, 32, 64, 128, 256],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
sigma=0.4,
with_deform=True,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.50,
alpha=0.4,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=1000,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 24
device_ids = range(4)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fovea_align_gn_r101_fpn_4gpu_2x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FOVEA',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
num_outs=5,
add_extra_convs=True),
bbox_head=dict(
type='FoveaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
base_edge_list=[16, 32, 64, 128, 256],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
sigma=0.4,
with_deform=True,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.50,
alpha=0.4,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=1000,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 24
device_ids = range(4)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fovea_align_gn_r50_fpn_4gpu_2x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FOVEA',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
num_outs=5,
add_extra_convs=True),
bbox_head=dict(
type='FoveaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
base_edge_list=[16, 32, 64, 128, 256],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
sigma=0.4,
with_deform=False,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.50,
alpha=0.4,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=1000,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(4)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fovea_r50_fpn_4gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
...@@ -242,6 +242,10 @@ Please refer to [Guided Anchoring](../configs/guided_anchoring/README.md) for de ...@@ -242,6 +242,10 @@ Please refer to [Guided Anchoring](../configs/guided_anchoring/README.md) for de
Please refer to [FCOS](../configs/fcos/README.md) for details. Please refer to [FCOS](../configs/fcos/README.md) for details.
### FoveaBox
Please refer to [FoveaBox](../configs/foveabox/README.md) for details.
### Grid R-CNN (plus) ### Grid R-CNN (plus)
Please refer to [Grid R-CNN](../configs/grid_rcnn/README.md) for details. Please refer to [Grid R-CNN](../configs/grid_rcnn/README.md) for details.
......
from .anchor_head import AnchorHead from .anchor_head import AnchorHead
from .fcos_head import FCOSHead from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .ga_retina_head import GARetinaHead from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
...@@ -11,5 +12,5 @@ from .ssd_head import SSDHead ...@@ -11,5 +12,5 @@ from .ssd_head import SSDHead
__all__ = [ __all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead', 'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
'RepPointsHead' 'RepPointsHead', 'FoveaHead'
] ]
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import multi_apply, multiclass_nms
from mmdet.ops import DeformConv
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
INF = 1e8
class FeatureAlign(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
deformable_groups=4):
super(FeatureAlign, self).__init__()
offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(
4, deformable_groups * offset_channels, 1, bias=False)
self.conv_adaption = DeformConv(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
deformable_groups=deformable_groups)
self.relu = nn.ReLU(inplace=True)
def init_weights(self):
normal_init(self.conv_offset, std=0.1)
normal_init(self.conv_adaption, std=0.01)
def forward(self, x, shape):
offset = self.conv_offset(shape)
x = self.relu(self.conv_adaption(x, offset))
return x
@HEADS.register_module
class FoveaHead(nn.Module):
"""FoveaBox: Beyond Anchor-based Object Detector
https://arxiv.org/abs/1904.03797
"""
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128,
512)),
sigma=0.4,
with_deform=False,
deformable_groups=4,
loss_cls=None,
loss_bbox=None,
conv_cfg=None,
norm_cfg=None):
super(FoveaHead, self).__init__()
self.num_classes = num_classes
self.cls_out_channels = num_classes - 1
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.sigma = sigma
self.with_deform = with_deform
self.deformable_groups = deformable_groups
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self):
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
# box branch
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.fovea_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
# cls branch
if not self.with_deform:
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.fovea_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
else:
self.cls_convs.append(
ConvModule(
self.feat_channels, (self.feat_channels * 4),
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.cls_convs.append(
ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.feature_adaption = FeatureAlign(
self.feat_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
self.fovea_cls = nn.Conv2d(
int(self.feat_channels * 4),
self.cls_out_channels,
3,
padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.fovea_cls, std=0.01, bias=bias_cls)
normal_init(self.fovea_reg, std=0.01)
if self.with_deform:
self.feature_adaption.init_weights()
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for reg_layer in self.reg_convs:
reg_feat = reg_layer(reg_feat)
bbox_pred = self.fovea_reg(reg_feat)
if self.with_deform:
cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
for cls_layer in self.cls_convs:
cls_feat = cls_layer(cls_feat)
cls_score = self.fovea_cls(cls_feat)
return cls_score, bbox_pred
def get_points(self, featmap_sizes, dtype, device, flatten=False):
points = []
for featmap_size in featmap_sizes:
x_range = torch.arange(
featmap_size[1], dtype=dtype, device=device) + 0.5
y_range = torch.arange(
featmap_size[0], dtype=dtype, device=device) + 0.5
y, x = torch.meshgrid(y_range, x_range)
if flatten:
points.append((y.flatten(), x.flatten()))
else:
points.append((y, x))
return points
def loss(self,
cls_scores,
bbox_preds,
gt_bbox_list,
gt_label_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
assert len(cls_scores) == len(bbox_preds)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
num_imgs = cls_scores[0].size(0)
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores)
flatten_bbox_preds = torch.cat(flatten_bbox_preds)
flatten_labels, flatten_bbox_targets = self.fovea_target(
gt_bbox_list, gt_label_list, featmap_sizes, points)
pos_inds = (flatten_labels > 0).nonzero().view(-1)
num_pos = len(pos_inds)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
if num_pos > 0:
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_bbox_targets = flatten_bbox_targets[pos_inds]
pos_weights = pos_bbox_targets.new_zeros(
pos_bbox_targets.size()) + 1.0
loss_bbox = self.loss_bbox(
pos_bbox_preds,
pos_bbox_targets,
pos_weights,
avg_factor=num_pos)
else:
loss_bbox = torch.tensor([0],
dtype=flatten_bbox_preds.dtype,
device=flatten_bbox_preds.device)
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
def fovea_target(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
label_list, bbox_target_list = multi_apply(
self.fovea_target_single,
gt_bbox_list,
gt_label_list,
featmap_size_list=featmap_sizes,
point_list=points)
flatten_labels = [
torch.cat([
labels_level_img.flatten() for labels_level_img in labels_level
]) for labels_level in zip(*label_list)
]
flatten_bbox_targets = [
torch.cat([
bbox_targets_level_img.reshape(-1, 4)
for bbox_targets_level_img in bbox_targets_level
]) for bbox_targets_level in zip(*bbox_target_list)
]
flatten_labels = torch.cat(flatten_labels)
flatten_bbox_targets = torch.cat(flatten_bbox_targets)
return flatten_labels, flatten_bbox_targets
def fovea_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
featmap_size_list=None,
point_list=None):
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
(gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
label_list = []
bbox_target_list = []
# for each pyramid, find the cls and box target
for base_len, (lower_bound, upper_bound), stride, featmap_size, \
(y, x) in zip(self.base_edge_list, self.scale_ranges,
self.strides, featmap_size_list, point_list):
labels = gt_labels_raw.new_zeros(featmap_size)
bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
4) + 1
# scale assignment
hit_indices = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero().flatten()
if len(hit_indices) == 0:
label_list.append(labels)
bbox_target_list.append(torch.log(bbox_targets))
continue
_, hit_index_order = torch.sort(-gt_areas[hit_indices])
hit_indices = hit_indices[hit_index_order]
gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
gt_labels = gt_labels_raw[hit_indices]
half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
# valid fovea area: left, right, top, down
pos_left = torch.ceil(
gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long().\
clamp(0, featmap_size[1] - 1)
pos_right = torch.floor(
gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long().\
clamp(0, featmap_size[1] - 1)
pos_top = torch.ceil(
gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long().\
clamp(0, featmap_size[0] - 1)
pos_down = torch.floor(
gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long().\
clamp(0, featmap_size[0] - 1)
for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
gt_bboxes_raw[hit_indices, :]):
labels[py1:py2 + 1, px1:px2 + 1] = label
bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
(stride * x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
(stride * y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
(gt_x2 - stride * x[py1:py2 + 1, px1:px2 + 1]) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
(gt_y2 - stride * y[py1:py2 + 1, px1:px2 + 1]) / base_len
bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
label_list.append(labels)
bbox_target_list.append(torch.log(bbox_targets))
return label_list, bbox_target_list
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, rescale=None):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
points = self.get_points(
featmap_sizes,
bbox_preds[0].dtype,
bbox_preds[0].device,
flatten=True)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
det_bboxes = self.get_bboxes_single(cls_score_list, bbox_pred_list,
featmap_sizes, points,
img_shape, scale_factor, cfg,
rescale)
result_list.append(det_bboxes)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
featmap_sizes,
point_list,
img_shape,
scale_factor,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(point_list)
det_bboxes = []
det_scores = []
for cls_score, bbox_pred, featmap_size, stride, base_len, (y, x) \
in zip(cls_scores, bbox_preds, featmap_sizes, self.strides,
self.base_edge_list, point_list):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).exp()
nms_pre = cfg.get('nms_pre', -1)
if (nms_pre > 0) and (scores.shape[0] > nms_pre):
max_scores, _ = scores.max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
y = y[topk_inds]
x = x[topk_inds]
x1 = (stride * x - base_len * bbox_pred[:, 0]).\
clamp(min=0, max=img_shape[1] - 1)
y1 = (stride * y - base_len * bbox_pred[:, 1]).\
clamp(min=0, max=img_shape[0] - 1)
x2 = (stride * x + base_len * bbox_pred[:, 2]).\
clamp(min=0, max=img_shape[1] - 1)
y2 = (stride * y + base_len * bbox_pred[:, 3]).\
clamp(min=0, max=img_shape[0] - 1)
bboxes = torch.stack([x1, y1, x2, y2], -1)
det_bboxes.append(bboxes)
det_scores.append(scores)
det_bboxes = torch.cat(det_bboxes)
if rescale:
det_bboxes /= det_bboxes.new_tensor(scale_factor)
det_scores = torch.cat(det_scores)
padding = det_scores.new_zeros(det_scores.shape[0], 1)
det_scores = torch.cat([padding, det_scores], dim=1)
det_bboxes, det_labels = multiclass_nms(det_bboxes, det_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
...@@ -4,6 +4,7 @@ from .double_head_rcnn import DoubleHeadRCNN ...@@ -4,6 +4,7 @@ from .double_head_rcnn import DoubleHeadRCNN
from .fast_rcnn import FastRCNN from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
from .fcos import FCOS from .fcos import FCOS
from .fovea import FOVEA
from .grid_rcnn import GridRCNN from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN from .mask_rcnn import MaskRCNN
...@@ -18,5 +19,5 @@ __all__ = [ ...@@ -18,5 +19,5 @@ __all__ = [
'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
'RepPointsDetector' 'RepPointsDetector', 'FOVEA'
] ]
from ..registry import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module
class FOVEA(SingleStageDetector):
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(FOVEA, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment