Commit 9acb38be authored by yhcao6's avatar yhcao6
Browse files

expose dcn extension interfaces

parent 6bd60eac
from .dcn import (DeformConv, DeformRoIPooling, ModulatedDeformRoIPoolingPack,
ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv, deform_roi_pooling)
from .nms import nms, soft_nms from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool from .roi_pool import RoIPool, roi_pool
__all__ = ['nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool'] __all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
from .functions import deform_conv
from .functions.modulated_dcn_func import (modulated_deform_conv,
deform_roi_pooling)
from .modules.deform_conv import DeformConv
from .modules.modulated_dcn import (
DeformRoIPooling, ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
ModulatedDeformConvPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
...@@ -2,7 +2,6 @@ import math ...@@ -2,7 +2,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import uniform_init
from torch.nn.modules.module import Module from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
...@@ -38,7 +37,7 @@ class DeformConv(Module): ...@@ -38,7 +37,7 @@ class DeformConv(Module):
for k in self.kernel_size: for k in self.kernel_size:
n *= k n *= k
stdv = 1. / math.sqrt(n) stdv = 1. / math.sqrt(n)
uniform_init(self, -stdv, stdv) self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset): def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride, return deform_conv(input, offset, self.weight, self.stride,
......
...@@ -6,7 +6,6 @@ from __future__ import print_function ...@@ -6,7 +6,6 @@ from __future__ import print_function
import math import math
import torch import torch
from mmcv.cnn import uniform_init
from torch import nn from torch import nn
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
...@@ -47,7 +46,8 @@ class ModulatedDeformConv(nn.Module): ...@@ -47,7 +46,8 @@ class ModulatedDeformConv(nn.Module):
for k in self.kernel_size: for k in self.kernel_size:
n *= k n *= k
stdv = 1. / math.sqrt(n) stdv = 1. / math.sqrt(n)
uniform_init(self, -stdv, stdv) self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
def forward(self, input, offset, mask): def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight, return modulated_deform_conv(input, offset, mask, self.weight,
......
...@@ -2,9 +2,9 @@ from setuptools import setup ...@@ -2,9 +2,9 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup( setup(
name='deform_conv', name='deform_conv_cuda',
ext_modules=[ ext_modules=[
CUDAExtension('deform_conv', [ CUDAExtension('deform_conv_cuda', [
'src/deform_conv_cuda.cpp', 'src/deform_conv_cuda.cpp',
'src/deform_conv_cuda_kernel.cu', 'src/deform_conv_cuda_kernel.cu',
]), ]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment