"...csrc/io/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4cb3d802c13f940b05fdb3c59af7bc148d680f8a"
Unverified commit 86cc430a, authored by Kai Chen and committed by GitHub.

Restructure the ops directory (#1073)

* restructure the ops directory

* add some repr strings
parent 8387aba8
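The restructure flattens each op's `functions/` and `modules/` subpackages into single modules, so the autograd Function wrappers and the `nn.Module` classes now live side by side and are re-exported by each package's `__init__.py`. A before/after sketch of the import paths (the `mmdet.ops` root package is an assumption; the diff below only shows relative imports):

```python
# Before this commit: Function wrappers and nn.Module classes lived in
# separate subpackages (paths reconstructed from the old imports below).
# from mmdet.ops.dcn.functions.deform_conv import deform_conv
# from mmdet.ops.dcn.modules.deform_conv import DeformConv

# After: one flat module per op, re-exported at the package root.
from mmdet.ops import DeformConv, deform_conv          # assumed root package
from mmdet.ops.dcn.deform_conv import DeformConvPack   # flat module path
```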
ops/__init__.py
@@ -2,7 +2,7 @@
 from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
                   ModulatedDeformConvPack, DeformRoIPooling,
                   DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
                   deform_conv, modulated_deform_conv, deform_roi_pooling)
-from .gcb import ContextBlock
+from .context_block import ContextBlock
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
ops/dcn/__init__.py
-from .functions.deform_conv import deform_conv, modulated_deform_conv
-from .functions.deform_pool import deform_roi_pooling
-from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
-                                  DeformConvPack, ModulatedDeformConvPack)
-from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
-                                  ModulatedDeformRoIPoolingPack)
+from .deform_conv import (deform_conv, modulated_deform_conv, DeformConv,
+                          DeformConvPack, ModulatedDeformConv,
+                          ModulatedDeformConvPack)
+from .deform_pool import (deform_roi_pooling, DeformRoIPooling,
+                          DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack)
 __all__ = [
     'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
ops/dcn/deform_conv.py
+import math
 import torch
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
-from .. import deform_conv_cuda
+from . import deform_conv_cuda
 class DeformConvFunction(Function):
@@ -52,6 +56,7 @@ class DeformConvFunction(Function):
         return output
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         input, offset, weight = ctx.saved_tensors
@@ -143,6 +148,7 @@ class ModulatedDeformConvFunction(Function):
         return output
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError
@@ -179,3 +185,153 @@ class ModulatedDeformConvFunction(Function):
 deform_conv = DeformConvFunction.apply
 modulated_deform_conv = ModulatedDeformConvFunction.apply
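Each rewritten `backward` also gains the `@once_differentiable` decorator. A standalone sketch (not part of the diff) of what the decorator buys: the backward body runs with gradient tracking disabled, and attempting a double backward raises an error instead of silently producing wrong higher-order gradients:

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable  # backward runs under no_grad; double backward errors
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

y = Square.apply(torch.ones(3, requires_grad=True))
y.sum().backward()  # first-order gradients work as usual
```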
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv, self).__init__()
assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, x):
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
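The `*Pack` variants above bundle the offset-predicting convolution with the deformable op, making them drop-in replacements for a plain conv layer. A hypothetical usage sketch (requires the compiled CUDA extension and a GPU):

```python
import torch

# conv_offset emits deformable_groups * 2 * kh * kw = 1 * 2 * 3 * 3 = 18
# offset channels, matching the nn.Conv2d constructed in __init__ above.
dcn = DeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
x = torch.randn(2, 16, 64, 64).cuda()
out = dcn(x)
print(out.shape)  # torch.Size([2, 32, 64, 64])
```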
ops/dcn/deform_pool.py
-from torch import nn
-from ..functions.deform_pool import deform_roi_pooling
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from . import deform_pool_cuda
class DeformRoIPoolingFunction(Function):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.trans_std)
return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction.apply
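Since `Function.apply` takes positional arguments only, callers pass the pooling configuration in the order of `forward` above. A hypothetical call sketch (GPU-only; the `(batch_idx, x1, y1, x2, y2)` RoI layout is an assumption, matching the usual convention in this codebase):

```python
import torch

data = torch.randn(1, 16, 32, 32).cuda()
rois = torch.tensor([[0., 4., 4., 20., 20.]]).cuda()  # assumed RoI layout
offset = data.new_empty(0)  # no_trans=True ignores the offset

# positional args: spatial_scale, out_size, out_channels, no_trans
out = deform_roi_pooling(data, rois, offset, 1. / 8, 7, 16, True)
print(out.shape)  # torch.Size([1, 16, 7, 7])
```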
 class DeformRoIPooling(nn.Module):
@@ -27,10 +96,11 @@ class DeformRoIPooling(nn.Module):
     def forward(self, data, rois, offset):
         if self.no_trans:
             offset = data.new_empty(0)
-        return deform_roi_pooling(
-            data, rois, offset, self.spatial_scale, self.out_size,
-            self.out_channels, self.no_trans, self.group_size, self.part_size,
-            self.sample_per_part, self.trans_std)
+        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                  self.out_size, self.out_channels,
+                                  self.no_trans, self.group_size,
+                                  self.part_size, self.sample_per_part,
+                                  self.trans_std)
 class DeformRoIPoolingPack(DeformRoIPooling):
@@ -73,10 +143,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
         assert data.size(1) == self.out_channels
         if self.no_trans:
             offset = data.new_empty(0)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
         else:
             n = rois.shape[0]
             offset = data.new_empty(0)
@@ -86,10 +157,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
                 self.sample_per_part, self.trans_std)
             offset = self.offset_fc(x.view(n, -1))
             offset = offset.view(n, 2, self.out_size, self.out_size)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
 class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
@@ -106,9 +178,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
                  num_offset_fcs=3,
                  num_mask_fcs=2,
                  deform_fc_channels=1024):
-        super(ModulatedDeformRoIPoolingPack, self).__init__(
-            spatial_scale, out_size, out_channels, no_trans, group_size,
-            part_size, sample_per_part, trans_std)
+        super(ModulatedDeformRoIPoolingPack,
+              self).__init__(spatial_scale, out_size, out_channels, no_trans,
+                             group_size, part_size, sample_per_part, trans_std)
         self.num_offset_fcs = num_offset_fcs
         self.num_mask_fcs = num_mask_fcs
@@ -151,10 +223,11 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
         assert data.size(1) == self.out_channels
         if self.no_trans:
             offset = data.new_empty(0)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
         else:
             n = rois.shape[0]
             offset = data.new_empty(0)

ops/dcn/functions/deform_pool.py (deleted)
import torch
from torch.autograd import Function
from .. import deform_pool_cuda
class DeformRoIPoolingFunction(Function):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.trans_std)
return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction.apply
ops/dcn/modules/deform_conv.py (deleted)
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.deform_conv import deform_conv, modulated_deform_conv
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, x):
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
ops/gcb/__init__.py (deleted)
from .context_block import ContextBlock
__all__ = [
'ContextBlock',
]
ops/masked_conv/__init__.py
-from .functions.masked_conv import masked_conv2d
-from .modules.masked_conv import MaskedConv2d
+from .masked_conv import masked_conv2d, MaskedConv2d
 __all__ = ['masked_conv2d', 'MaskedConv2d']
ops/masked_conv/masked_conv.py
 import math
 import torch
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
-from .. import masked_conv2d_cuda
+from . import masked_conv2d_cuda
 class MaskedConv2dFunction(Function):
@@ -49,8 +53,37 @@ class MaskedConv2dFunction(Function):
         return output
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         return (None, ) * 5
 masked_conv2d = MaskedConv2dFunction.apply
class MaskedConv2d(nn.Conv2d):
"""A MaskedConv2d which inherits the official Conv2d.
    The masked forward does not implement a backward pass and currently
    only supports stride=1.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
super(MaskedConv2d,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, input, mask=None):
        if mask is None:  # fall back to the normal Conv2d
return super(MaskedConv2d, self).forward(input)
else:
return masked_conv2d(input, mask, self.weight, self.bias,
self.padding)
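A hypothetical usage sketch for the class above (GPU-only; the mask layout is an assumption, check the contract of the `masked_conv2d_cuda` kernel):

```python
import torch

conv = MaskedConv2d(16, 16, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 16, 32, 32).cuda()

dense = conv(x)  # mask is None: falls back to the plain nn.Conv2d forward

mask = (torch.rand(1, 1, 32, 32).cuda() > 0.5).float()  # assumed mask layout
sparse = conv(x, mask)  # masked CUDA path; stride must be 1
```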
ops/masked_conv/modules/masked_conv.py (deleted)
import torch.nn as nn
from ..functions.masked_conv import masked_conv2d
class MaskedConv2d(nn.Conv2d):
"""A MaskedConv2d which inherits the official Conv2d.
The masked forward doesn't implement the backward function and only
supports the stride parameter to be 1 currently.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
super(MaskedConv2d,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, input, mask=None):
if mask is None: # fallback to the normal Conv2d
return super(MaskedConv2d, self).forward(input)
else:
return masked_conv2d(input, mask, self.weight, self.bias,
self.padding)
ops/roi_align/__init__.py
-from .functions.roi_align import roi_align
-from .modules.roi_align import RoIAlign
+from .roi_align import roi_align, RoIAlign
 __all__ = ['roi_align', 'RoIAlign']
ops/roi_align/modules/roi_align.py (deleted)
import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.roi_align import roi_align
class RoIAlign(nn.Module):
def __init__(self,
out_size,
spatial_scale,
sample_num=0,
use_torchvision=False):
super(RoIAlign, self).__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.use_torchvision = use_torchvision
def forward(self, features, rois):
if self.use_torchvision:
from torchvision.ops import roi_align as tv_roi_align
return tv_roi_align(features, rois, _pair(self.out_size),
self.spatial_scale, self.sample_num)
else:
return roi_align(features, rois, self.out_size, self.spatial_scale,
self.sample_num)
ops/roi_align/roi_align.py
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
-from .. import roi_align_cuda
+from . import roi_align_cuda
 class RoIAlignFunction(Function):
@@ -28,6 +30,7 @@ class RoIAlignFunction(Function):
         return output
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         feature_size = ctx.feature_size
         spatial_scale = ctx.spatial_scale
@@ -51,3 +54,34 @@ class RoIAlignFunction(Function):
 roi_align = RoIAlignFunction.apply
class RoIAlign(nn.Module):
def __init__(self,
out_size,
spatial_scale,
sample_num=0,
use_torchvision=False):
super(RoIAlign, self).__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.use_torchvision = use_torchvision
def forward(self, features, rois):
if self.use_torchvision:
from torchvision.ops import roi_align as tv_roi_align
return tv_roi_align(features, rois, _pair(self.out_size),
self.spatial_scale, self.sample_num)
else:
return roi_align(features, rois, self.out_size, self.spatial_scale,
self.sample_num)
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
self.out_size, self.spatial_scale, self.sample_num)
format_str += ', use_torchvision={})'.format(self.use_torchvision)
return format_str
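The new `__repr__` (one of the "repr strings" from the commit message) makes printed models self-describing; for example:

```python
roi_layer = RoIAlign(out_size=7, spatial_scale=1. / 16, sample_num=2)
print(roi_layer)
# RoIAlign(out_size=7, spatial_scale=0.0625, sample_num=2, use_torchvision=False)
```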