"vscode:/vscode.git/clone" did not exist on "77e0c654e914ddb303d4f5671886aa13f84ca498"
Commit 9acb38be authored by yhcao6's avatar yhcao6
Browse files

expose dcn extension interfaces

parent 6bd60eac
from .dcn import (DeformConv, DeformRoIPooling, ModulatedDeformRoIPoolingPack,
ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv, deform_roi_pooling)
from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
__all__ = ['nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool']
__all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
from .functions import deform_conv
from .functions.modulated_dcn_func import (modulated_deform_conv,
deform_roi_pooling)
from .modules.deform_conv import DeformConv
from .modules.modulated_dcn import (
DeformRoIPooling, ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
ModulatedDeformConvPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
......@@ -2,7 +2,6 @@ import math
import torch
import torch.nn as nn
from mmcv.cnn import uniform_init
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
......@@ -38,7 +37,7 @@ class DeformConv(Module):
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
uniform_init(self, -stdv, stdv)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
......
......@@ -6,7 +6,6 @@ from __future__ import print_function
import math
import torch
from mmcv.cnn import uniform_init
from torch import nn
from torch.nn.modules.utils import _pair
......@@ -47,7 +46,8 @@ class ModulatedDeformConv(nn.Module):
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
uniform_init(self, -stdv, stdv)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight,
......
......@@ -2,9 +2,9 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='deform_conv',
name='deform_conv_cuda',
ext_modules=[
CUDAExtension('deform_conv', [
CUDAExtension('deform_conv_cuda', [
'src/deform_conv_cuda.cpp',
'src/deform_conv_cuda_kernel.cu',
]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment