"README_cn.md" did not exist on "ad08b8ce131bacd6f61dfcd49e5f1af3cac76ca7"
Commit 89022a26 authored by Jiaqi Wang, committed by Kai Chen

Code of CVPR 2019 Paper: Region Proposal by Guided Anchoring (#594)

* add two stage w/o neck and w/ upperneck

* add rpn r50 c4

* update c4 configs

* fix

* config update

* update config

* minor update

* mask rcnn support c4 train and test

* lr fix

* cascade support upper_neck

* add cascade c4 config

* update config

* update

* update res_layer to new interface

* refactoring

* c4 configs update

* refactoring

* update rpn_c4 config

* rename upper_neck as shared_head

* update

* update configs

* update

* update c4 configs

* update according to comments

* update

* add ga rpn

* test bug fix

* test bug fix when loc_filter_thr is large

* update configs

* update configs

* add ga_retinanet

* ga test bug fix

* update configs

* update

* init masked conv

* update

* update masked conv

* update

* support no ga_sampler

* update

* update

* test with masked_conv

* update comment

* fix flake errors

* fix flake8 errors

* refactor bounded iou loss

* refactor ga_retina_head

* update configs

* refactor masked conv

* fix flake8 error

* refactor guided_anchor_head and ga_rpn_head

* update configs

* use_sigmoid_cls -> cls_sigmoid_loss; use_focal_loss -> cls_focal_loss

* refactoring

* cls_sigmoid_loss -> use_sigmoid_cls

* fix flake8 error

* add some docs

* rename normalize to norm_cfg

* update configs

* add readme

* update ga_faster config

* update readme

* update readme

* rename configs as r50_caffe

* merge master

* refactor guided anchor target

* update readme

* update approx max iou assigner

* refactor guided anchor target

* update docstring

* refactor ga heads

* fix flake8 error

* update readme

* update model url

* update comments

* refactor get anchors

* update docstring

* not use_loc_filter during training

* add R-101 results

* update to support build loss api

* fix flake8 error

* update readme with x-101 performances

* update readme

* add a link in project readme

* refactor code about ga shape inside flags

* update

* update

* add x101 config files

* add ga_rpn r101 config

* update some comments

* add comments

* add comments

* update comments

* fix flake8 error
parent 3cb84acc
# losses/__init__.py
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .smooth_l1_loss import SmoothL1Loss
from .iou_loss import IoULoss

__all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'IoULoss']
# losses/iou_loss.py
import torch.nn as nn

from mmdet.core import weighted_iou_loss

from ..registry import LOSSES


@LOSSES.register_module
class IoULoss(nn.Module):

    def __init__(self, style='naive', beta=0.2, eps=1e-3, loss_weight=1.0):
        super(IoULoss, self).__init__()
        self.style = style
        self.beta = beta
        self.eps = eps
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight, *args, **kwargs):
        loss = self.loss_weight * weighted_iou_loss(
            pred,
            target,
            weight,
            style=self.style,  # 'naive' or 'bounded' (bounded IoU loss)
            beta=self.beta,
            eps=self.eps,
            *args,
            **kwargs)
        return loss
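For orientation, a registered loss like this is normally built by name from a config dict elsewhere in mmdetection; the direct call below is a minimal sketch (box values and shapes are illustrative assumptions, not from this diff; extra arguments such as avg_factor pass through *args/**kwargs):

import torch

# Hypothetical direct use: (N, 4) boxes as (x1, y1, x2, y2), with a
# per-coordinate weight tensor of the same shape.
loss_fn = IoULoss(style='naive', loss_weight=1.0)
pred = torch.tensor([[10., 10., 20., 20.], [30., 30., 50., 60.]])
target = torch.tensor([[12., 12., 22., 22.], [28., 32., 48., 62.]])
weight = torch.ones_like(pred)
loss = loss_fn(pred, target, weight)

# The equivalent config-style entry:
loss_bbox = dict(type='IoULoss', style='naive', loss_weight=1.0)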
# ops/__init__.py
# (imports of the deform ops above this hunk are unchanged and omitted)
from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
from .masked_conv import MaskedConv2d

__all__ = [
    'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
    'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
    'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
    'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
    'MaskedConv2d'
]
# masked_conv/__init__.py
from .functions.masked_conv import masked_conv2d
from .modules.masked_conv import MaskedConv2d

__all__ = ['masked_conv2d', 'MaskedConv2d']
# masked_conv/functions/masked_conv.py
import math

import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair

from .. import masked_conv2d_cuda


class MaskedConv2dFunction(Function):

    @staticmethod
    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'masked_conv2d currently only supports stride=1.')
        if not features.is_cuda:
            raise NotImplementedError
        out_channel, in_channel, kernel_h, kernel_w = weight.size()
        batch_size = features.size(0)
        out_h = int(
            math.floor((features.size(2) + 2 * pad_h -
                        (kernel_h - 1) - 1) / stride_h + 1))
        out_w = int(
            math.floor((features.size(3) + 2 * pad_w -
                        (kernel_w - 1) - 1) / stride_w + 1))
        # Coordinates of the active (non-zero) mask locations.
        mask_inds = torch.nonzero(mask[0] > 0)
        mask_h_idx = mask_inds[:, 0].contiguous()
        mask_w_idx = mask_inds[:, 1].contiguous()
        # im2col buffer holding one column per active location.
        data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                      mask_inds.size(0))
        masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
                                                 mask_w_idx, kernel_h,
                                                 kernel_w, pad_h, pad_w,
                                                 data_col)
        # One GEMM over the gathered columns:
        # (oc, ic*kh*kw) x (ic*kh*kw, mask_cnt) + bias.
        masked_output = torch.addmm(1, bias[:, None], 1,
                                    weight.view(out_channel, -1), data_col)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        # Scatter the per-location results back into a dense output map.
        masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
                                                 mask_w_idx, out_h, out_w,
                                                 out_channel, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Backward is intentionally unimplemented: one None per forward input.
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply
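For intuition, the function above should match an ordinary convolution whose output is zeroed at inactive locations. Below is a minimal CPU reference sketch under the assumptions the asserts encode (batch size 1, stride 1), plus "same" padding so the output map lines up with the mask; masked_conv2d_reference is a hypothetical name, not part of this diff:

import torch
import torch.nn.functional as F

def masked_conv2d_reference(features, mask, weight, bias, padding=0):
    # features: (1, ic, h, w); mask: (1, h, w); weight: (oc, ic, kh, kw).
    # Assumes 'same' padding, i.e. the conv output has the mask's spatial
    # size, which matches how the GA heads use the op.
    out = F.conv2d(features, weight, bias, stride=1, padding=padding)
    keep = (mask > 0).to(out.dtype).unsqueeze(1)  # broadcast over channels
    return out * keep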
# masked_conv/modules/masked_conv.py
import torch.nn as nn

from ..functions.masked_conv import masked_conv2d


class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    The masked forward does not implement the backward function and
    currently only supports a stride of 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
        if mask is None:  # fall back to the normal Conv2d
            return super(MaskedConv2d, self).forward(input)
        else:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)
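A small usage sketch of the module (the shapes and the 0.9 threshold are illustrative; the masked path requires CUDA):

import torch

conv = MaskedConv2d(256, 256, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 256, 32, 32).cuda()

# No mask: behaves exactly like nn.Conv2d.
dense_out = conv(x)

# (1, H, W) binary mask: convolution is evaluated only at active
# locations; the rest of the output stays zero.
mask = (torch.rand(1, 32, 32).cuda() > 0.9).float()
sparse_out = conv(x, mask)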
# masked_conv/setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='masked_conv2d_cuda',
    ext_modules=[
        CUDAExtension('masked_conv2d_cuda', [
            'src/masked_conv2d_cuda.cpp',
            'src/masked_conv2d_kernel.cu',
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
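Presumably, as with the other CUDA ops in the repo, the extension is compiled in place with something like "python setup.py build_ext --inplace" from the masked_conv directory, producing the masked_conv2d_cuda module that the autograd function above imports.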
// src/masked_conv2d_cuda.cpp
#include <torch/extension.h>

#include <cmath>
#include <vector>

int MaskedIm2colForwardLaucher(const at::Tensor im, const int height,
                               const int width, const int channels,
                               const int kernel_h, const int kernel_w,
                               const int pad_h, const int pad_w,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor col);

int MaskedCol2imForwardLaucher(const at::Tensor col, const int height,
                               const int width, const int channels,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor im);

#define CHECK_CUDA(x) \
  AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) \
  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

int masked_im2col_forward_cuda(const at::Tensor im, const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int kernel_h,
                               const int kernel_w, const int pad_h,
                               const int pad_w, at::Tensor col) {
  CHECK_INPUT(im);
  CHECK_INPUT(mask_h_idx);
  CHECK_INPUT(mask_w_idx);
  CHECK_INPUT(col);
  // im: (n, ic, h, w), kernel size (kh, kw)
  // kernel: (oc, ic * kh * kw), col: (kh * kw * ic, ow * oh)
  int channels = im.size(1);
  int height = im.size(2);
  int width = im.size(3);
  int mask_cnt = mask_h_idx.size(0);
  MaskedIm2colForwardLaucher(im, height, width, channels, kernel_h, kernel_w,
                             pad_h, pad_w, mask_h_idx, mask_w_idx, mask_cnt,
                             col);
  return 1;
}

int masked_col2im_forward_cuda(const at::Tensor col,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, int height,
                               int width, int channels, at::Tensor im) {
  CHECK_INPUT(col);
  CHECK_INPUT(mask_h_idx);
  CHECK_INPUT(mask_w_idx);
  CHECK_INPUT(im);
  // im: (n, ic, h, w), kernel size (kh, kw)
  // kernel: (oc, ic * kh * kw), col: (kh * kw * ic, ow * oh)
  int mask_cnt = mask_h_idx.size(0);
  MaskedCol2imForwardLaucher(col, height, width, channels, mask_h_idx,
                             mask_w_idx, mask_cnt, im);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("masked_im2col_forward", &masked_im2col_forward_cuda,
        "masked_im2col forward (CUDA)");
  m.def("masked_col2im_forward", &masked_col2im_forward_cuda,
        "masked_col2im forward (CUDA)");
}
// src/masked_conv2d_kernel.cu
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

#define THREADS_PER_BLOCK 1024

inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 65000;
  return min(optimal_block_num, max_block_num);
}

template <typename scalar_t>
__global__ void MaskedIm2colForward(const int n, const scalar_t *data_im,
                                    const int height, const int width,
                                    const int kernel_h, const int kernel_w,
                                    const int pad_h, const int pad_w,
                                    const long *mask_h_idx,
                                    const long *mask_w_idx, const int mask_cnt,
                                    scalar_t *data_col) {
  // One thread per (channel, mask location) pair: n = mask_cnt * channels.
  CUDA_1D_KERNEL_LOOP(index, n) {
    const int m_index = index % mask_cnt;
    const int h_col = mask_h_idx[m_index];
    const int w_col = mask_w_idx[m_index];
    const int c_im = index / mask_cnt;
    const int c_col = c_im * kernel_h * kernel_w;
    const int h_offset = h_col - pad_h;
    const int w_offset = w_col - pad_w;
    scalar_t *data_col_ptr = data_col + c_col * mask_cnt + m_index;
    // Copy the kh x kw receptive field of this channel into the column,
    // zero-filling anything that falls outside the image.
    for (int i = 0; i < kernel_h; ++i) {
      int h_im = h_offset + i;
      for (int j = 0; j < kernel_w; ++j) {
        int w_im = w_offset + j;
        if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
          *data_col_ptr =
              (scalar_t)data_im[(c_im * height + h_im) * width + w_im];
        } else {
          *data_col_ptr = 0.0;
        }
        data_col_ptr += mask_cnt;
      }
    }
  }
}

int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
                               const int width, const int channels,
                               const int kernel_h, const int kernel_w,
                               const int pad_h, const int pad_w,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor top_data) {
  const int output_size = mask_cnt * channels;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      bottom_data.type(), "MaskedIm2colLaucherForward", ([&] {
        const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
        const long *mask_h_idx_ = mask_h_idx.data<long>();
        const long *mask_w_idx_ = mask_w_idx.data<long>();
        scalar_t *top_data_ = top_data.data<scalar_t>();
        MaskedIm2colForward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data_, height, width, kernel_h, kernel_w,
                pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
      }));
  THCudaCheck(cudaGetLastError());
  return 1;
}
template <typename scalar_t>
__global__ void MaskedCol2imForward(const int n, const scalar_t *data_col,
                                    const int height, const int width,
                                    const int channels, const long *mask_h_idx,
                                    const long *mask_w_idx, const int mask_cnt,
                                    scalar_t *data_im) {
  // One thread per (channel, mask location) pair: n = mask_cnt * channels.
  CUDA_1D_KERNEL_LOOP(index, n) {
    const int m_index = index % mask_cnt;
    const int h_im = mask_h_idx[m_index];
    const int w_im = mask_w_idx[m_index];
    const int c_im = index / mask_cnt;
    // Scatter each computed value back to its (h, w) mask location in the
    // dense output; unmasked locations keep their zero initialization.
    data_im[(c_im * height + h_im) * width + w_im] = data_col[index];
  }
}

int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
                               const int width, const int channels,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor top_data) {
  const int output_size = mask_cnt * channels;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      bottom_data.type(), "MaskedCol2imLaucherForward", ([&] {
        const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
        const long *mask_h_idx_ = mask_h_idx.data<long>();
        const long *mask_w_idx_ = mask_w_idx.data<long>();
        scalar_t *top_data_ = top_data.data<scalar_t>();
        MaskedCol2imForward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data_, height, width, channels,
                mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
      }));
  THCudaCheck(cudaGetLastError());
  return 1;
}
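To make the two kernels concrete, here is a hedged pure-PyTorch sketch of the same gather, GEMM, scatter pipeline. The names masked_im2col_reference and masked_conv2d_from_cols are illustrative, not part of the extension, and "same" padding is assumed so mask indices are valid output indices:

import torch

def masked_im2col_reference(im, mask_h_idx, mask_w_idx, kh, kw, pad):
    # im: (1, ic, h, w) -> col: (ic * kh * kw, mask_cnt), the same layout
    # MaskedIm2colForward writes: one column per active mask location.
    _, ic, h, w = im.shape
    mask_cnt = mask_h_idx.numel()
    col = im.new_zeros(ic * kh * kw, mask_cnt)
    for m in range(mask_cnt):
        h0 = int(mask_h_idx[m]) - pad
        w0 = int(mask_w_idx[m]) - pad
        row = 0
        for c in range(ic):
            for i in range(kh):
                for j in range(kw):
                    hi, wj = h0 + i, w0 + j
                    if 0 <= hi < h and 0 <= wj < w:  # zero-fill outside
                        col[row, m] = im[0, c, hi, wj]
                    row += 1
    return col

def masked_conv2d_from_cols(im, mask, weight, bias, pad):
    # Mirrors the CUDA path end to end: gather columns, one GEMM, then
    # scatter into a zero-initialized dense map (cf. MaskedCol2imForward).
    oc, ic, kh, kw = weight.shape
    inds = torch.nonzero(mask[0] > 0)
    col = masked_im2col_reference(im, inds[:, 0], inds[:, 1], kh, kw, pad)
    vals = bias[:, None] + weight.view(oc, -1) @ col  # (oc, mask_cnt)
    out = im.new_zeros(1, oc, im.size(2), im.size(3))
    out[0, :, inds[:, 0], inds[:, 1]] = vals  # the col2im scatter
    return out

At the active locations this should agree with torch.nn.functional.conv2d(im, weight, bias, padding=pad) evaluated at the same indices, which is a convenient correctness check for the CUDA path.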