Unverified Commit 86cc430a authored by Kai Chen, committed by GitHub

Restructure the ops directory (#1073)

* restructure the ops directory

* add some repr strings
parent 8387aba8
roi_pool/__init__.py
-from .functions.roi_pool import roi_pool
-from .modules.roi_pool import RoIPool
+from .roi_pool import roi_pool, RoIPool
 
 __all__ = ['roi_pool', 'RoIPool']
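The restructuring is transparent to callers, since the package `__init__.py` keeps re-exporting the same public names. A minimal sketch (the `mmdet.ops` import path is an assumption, not shown in this diff):

# Downstream code is unaffected by the move: the same names are re-exported,
# only the internal functions/ and modules/ layout changed.
from mmdet.ops import roi_pool, RoIPool  # assumed package path

layer = RoIPool(out_size=7, spatial_scale=0.125)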
roi_pool/modules/roi_pool.py (deleted)
-import torch.nn as nn
-from torch.nn.modules.utils import _pair
-
-from ..functions.roi_pool import roi_pool
-
-
-class RoIPool(nn.Module):
-
-    def __init__(self, out_size, spatial_scale, use_torchvision=False):
-        super(RoIPool, self).__init__()
-
-        self.out_size = out_size
-        self.spatial_scale = float(spatial_scale)
-        self.use_torchvision = use_torchvision
-
-    def forward(self, features, rois):
-        if self.use_torchvision:
-            from torchvision.ops import roi_pool as tv_roi_pool
-            return tv_roi_pool(features, rois, _pair(self.out_size),
-                               self.spatial_scale)
-        else:
-            return roi_pool(features, rois, self.out_size, self.spatial_scale)
roi_pool/roi_pool.py (moved from roi_pool/functions/roi_pool.py)
 import torch
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
-from .. import roi_pool_cuda
+from . import roi_pool_cuda
 
 
 class RoIPoolFunction(Function):
@@ -27,6 +29,7 @@ class RoIPoolFunction(Function):
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         assert grad_output.is_cuda
         spatial_scale = ctx.spatial_scale
@@ -45,3 +48,28 @@ class RoIPoolFunction(Function):
 
 
 roi_pool = RoIPoolFunction.apply
+
+
+class RoIPool(nn.Module):
+
+    def __init__(self, out_size, spatial_scale, use_torchvision=False):
+        super(RoIPool, self).__init__()
+
+        self.out_size = out_size
+        self.spatial_scale = float(spatial_scale)
+        self.use_torchvision = use_torchvision
+
+    def forward(self, features, rois):
+        if self.use_torchvision:
+            from torchvision.ops import roi_pool as tv_roi_pool
+            return tv_roi_pool(features, rois, _pair(self.out_size),
+                               self.spatial_scale)
+        else:
+            return roi_pool(features, rois, self.out_size, self.spatial_scale)
+
+    def __repr__(self):
+        format_str = self.__class__.__name__
+        format_str += '(out_size={}, spatial_scale={}'.format(
+            self.out_size, self.spatial_scale)
+        format_str += ', use_torchvision={})'.format(self.use_torchvision)
+        return format_str
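A minimal usage sketch of the merged module (not part of the commit). The `mmdet.ops` import path, the input shapes, and the (batch_index, x1, y1, x2, y2) RoI layout are assumptions; the non-torchvision path additionally requires the compiled roi_pool_cuda extension:

import torch
from mmdet.ops import RoIPool, roi_pool  # assumed import path

feats = torch.rand(2, 256, 32, 32, device='cuda')
# Assumed RoI layout: (batch_index, x1, y1, x2, y2) in image coordinates.
rois = torch.tensor([[0., 0., 0., 63., 63.],
                     [1., 16., 16., 127., 127.]], device='cuda')

layer = RoIPool(out_size=7, spatial_scale=1.0 / 8)
pooled = layer(feats, rois)                    # -> (2, 256, 7, 7)
pooled_fn = roi_pool(feats, rois, 7, 1.0 / 8)  # functional form, same op

# The __repr__ added in this commit:
print(layer)  # RoIPool(out_size=7, spatial_scale=0.125, use_torchvision=False)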
sigmoid_focal_loss/__init__.py
-from .modules.sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
+from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
 
 __all__ = ['SigmoidFocalLoss', 'sigmoid_focal_loss']
sigmoid_focal_loss/modules/sigmoid_focal_loss.py (deleted)
-from torch import nn
-
-from ..functions.sigmoid_focal_loss import sigmoid_focal_loss
-
-
-# TODO: remove this module
-class SigmoidFocalLoss(nn.Module):
-
-    def __init__(self, gamma, alpha):
-        super(SigmoidFocalLoss, self).__init__()
-        self.gamma = gamma
-        self.alpha = alpha
-
-    def forward(self, logits, targets):
-        assert logits.is_cuda
-        loss = sigmoid_focal_loss(logits, targets, self.gamma, self.alpha)
-        return loss.sum()
-
-    def __repr__(self):
-        tmpstr = self.__class__.__name__ + "("
-        tmpstr += "gamma=" + str(self.gamma)
-        tmpstr += ", alpha=" + str(self.alpha)
-        tmpstr += ")"
-        return tmpstr
sigmoid_focal_loss/sigmoid_focal_loss.py (moved from sigmoid_focal_loss/functions/sigmoid_focal_loss.py)
+import torch.nn as nn
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 
-from .. import sigmoid_focal_loss_cuda
+from . import sigmoid_focal_loss_cuda
 
 
 class SigmoidFocalLossFunction(Function):
@@ -32,3 +33,22 @@ class SigmoidFocalLossFunction(Function):
 
 
 sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+
+
+# TODO: remove this module
+class SigmoidFocalLoss(nn.Module):
+
+    def __init__(self, gamma, alpha):
+        super(SigmoidFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+
+    def forward(self, logits, targets):
+        assert logits.is_cuda
+        loss = sigmoid_focal_loss(logits, targets, self.gamma, self.alpha)
+        return loss.sum()
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + '(gamma={}, alpha={})'.format(
+            self.gamma, self.alpha)
+        return tmpstr
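A similar usage sketch for the focal-loss module (not part of the commit). The import path, the (num_samples, num_classes) logits shape, and the integer class-label convention for targets are assumptions inherited from the sigmoid_focal_loss_cuda kernel; the op requires CUDA tensors, as the assert in forward enforces:

import torch
from mmdet.ops import SigmoidFocalLoss, sigmoid_focal_loss  # assumed import path

criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
print(criterion)  # SigmoidFocalLoss(gamma=2.0, alpha=0.25) -- the new __repr__

logits = torch.randn(8, 80, device='cuda', requires_grad=True)
targets = torch.randint(1, 81, (8, ), device='cuda')  # assumed label convention

loss = criterion(logits, targets)  # module form returns loss.sum()
loss.backward()

per_elem = sigmoid_focal_loss(logits, targets, 2.0, 0.25)  # functional, unreduced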