"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "436efb8e9ea1b7100d316e58d4c48866fda3916d"
Unverified Commit 350fdd7a authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

support torchvision RoIPool and RoIAlign (#990)

parent c101398c
from torch.autograd import Function from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import roi_align_cuda from .. import roi_align_cuda
...@@ -7,17 +8,8 @@ class RoIAlignFunction(Function): ...@@ -7,17 +8,8 @@ class RoIAlignFunction(Function):
@staticmethod @staticmethod
def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0): def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
if isinstance(out_size, int): out_h, out_w = _pair(out_size)
out_h = out_size assert isinstance(out_h, int) and isinstance(out_w, int)
out_w = out_size
elif isinstance(out_size, tuple):
assert len(out_size) == 2
assert isinstance(out_size[0], int)
assert isinstance(out_size[1], int)
out_h, out_w = out_size
else:
raise TypeError(
'"out_size" must be an integer or tuple of integers')
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
ctx.sample_num = sample_num ctx.sample_num = sample_num
ctx.save_for_backward(rois) ctx.save_for_backward(rois)
......
from torch.nn.modules.module import Module import torch.nn as nn
from ..functions.roi_align import RoIAlignFunction from torch.nn.modules.utils import _pair
from ..functions.roi_align import roi_align
class RoIAlign(Module):
def __init__(self, out_size, spatial_scale, sample_num=0): class RoIAlign(nn.Module):
def __init__(self,
out_size,
spatial_scale,
sample_num=0,
use_torchvision=False):
super(RoIAlign, self).__init__() super(RoIAlign, self).__init__()
self.out_size = out_size self.out_size = out_size
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num) self.sample_num = int(sample_num)
self.use_torchvision = use_torchvision
def forward(self, features, rois): def forward(self, features, rois):
return RoIAlignFunction.apply(features, rois, self.out_size, if self.use_torchvision:
self.spatial_scale, self.sample_num) from torchvision.ops import roi_align as tv_roi_align
return tv_roi_align(features, rois, _pair(self.out_size),
self.spatial_scale, self.sample_num)
else:
return roi_align(features, rois, self.out_size, self.spatial_scale,
self.sample_num)
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import roi_pool_cuda from .. import roi_pool_cuda
...@@ -8,18 +9,9 @@ class RoIPoolFunction(Function): ...@@ -8,18 +9,9 @@ class RoIPoolFunction(Function):
@staticmethod @staticmethod
def forward(ctx, features, rois, out_size, spatial_scale): def forward(ctx, features, rois, out_size, spatial_scale):
if isinstance(out_size, int):
out_h = out_size
out_w = out_size
elif isinstance(out_size, tuple):
assert len(out_size) == 2
assert isinstance(out_size[0], int)
assert isinstance(out_size[1], int)
out_h, out_w = out_size
else:
raise TypeError(
'"out_size" must be an integer or tuple of integers')
assert features.is_cuda assert features.is_cuda
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
ctx.save_for_backward(rois) ctx.save_for_backward(rois)
num_channels = features.size(1) num_channels = features.size(1)
num_rois = rois.size(0) num_rois = rois.size(0)
......
from torch.nn.modules.module import Module import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.roi_pool import roi_pool from ..functions.roi_pool import roi_pool
class RoIPool(Module): class RoIPool(nn.Module):
def __init__(self, out_size, spatial_scale): def __init__(self, out_size, spatial_scale, use_torchvision=False):
super(RoIPool, self).__init__() super(RoIPool, self).__init__()
self.out_size = out_size self.out_size = out_size
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
self.use_torchvision = use_torchvision
def forward(self, features, rois): def forward(self, features, rois):
return roi_pool(features, rois, self.out_size, self.spatial_scale) if self.use_torchvision:
from torchvision.ops import roi_pool as tv_roi_pool
return tv_roi_pool(features, rois, _pair(self.out_size),
self.spatial_scale)
else:
return roi_pool(features, rois, self.out_size, self.spatial_scale)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment