"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "437b6e702fda37aa8d1e230601c85e989f858a8c"
Unverified Commit 85c0f85b authored by WINDSKY45's avatar WINDSKY45 Committed by GitHub
Browse files

[Enhancement] Add type hints in mmcv/ops/focal_loss.py. (#1994)



* [Enhance] Add type hints in `focal_loss.py`.

* Use torch.LongTensor as the type of `target`.

* Fix the missing type hint.

* Remove some unnecessary type hints.

* Add missing type hints and fix the type hint of `target`.

* Fix the format of the type hints.

* minor refinement

* minor fix
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent f51bcf50
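As background for the `target` annotation discussed in the commit message, the following is a rough illustrative sketch (not part of the commit): the legacy tensor type classes `torch.LongTensor` and `torch.cuda.LongTensor` only match int64 tensors on their respective devices, which is why the hint lists both. The tensor values and variable names here are made up for illustration.

```python
# Illustrative sketch only (not from the diff): why `target` is annotated as
# Union[torch.LongTensor, torch.cuda.LongTensor]. Each legacy type class matches
# int64 tensors on one device, so both variants are listed.
import torch

target = torch.tensor([0, 1, 2])  # int64 class labels on CPU

# A CPU long tensor is a torch.LongTensor but not a torch.cuda.LongTensor.
print(isinstance(target, torch.LongTensor))       # True
print(isinstance(target, torch.cuda.LongTensor))  # False

# The same runtime check that the annotated forward() performs on `target`:
assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
```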
mmcv/ops/focal_loss.py

```diff
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
 from torch.autograd import Function
@@ -15,7 +17,9 @@ ext_module = ext_loader.load_ext('_ext', [
 class SigmoidFocalLossFunction(Function):
 
     @staticmethod
-    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+    def symbolic(g, input: torch.Tensor, target: torch.LongTensor,
+                 gamma: float, alpha: float, weight: torch.Tensor,
+                 reduction: str):
         return g.op(
             'mmcv::MMCVSigmoidFocalLoss',
             input,
@@ -27,12 +31,12 @@ class SigmoidFocalLossFunction(Function):
     @staticmethod
     def forward(ctx,
-                input,
-                target,
-                gamma=2.0,
-                alpha=0.25,
-                weight=None,
-                reduction='mean'):
+                input: torch.Tensor,
+                target: Union[torch.LongTensor, torch.cuda.LongTensor],
+                gamma: float = 2.0,
+                alpha: float = 0.25,
+                weight: Optional[torch.Tensor] = None,
+                reduction: str = 'mean') -> torch.Tensor:
         assert isinstance(
             target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
@@ -64,7 +68,7 @@ class SigmoidFocalLossFunction(Function):
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         input, target, weight = ctx.saved_tensors
         grad_input = input.new_zeros(input.size())
@@ -88,14 +92,22 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
 class SigmoidFocalLoss(nn.Module):
 
-    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+    def __init__(self,
+                 gamma: float,
+                 alpha: float,
+                 weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean'):
         super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)
         self.reduction = reduction
 
-    def forward(self, input, target):
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: Union[torch.LongTensor, torch.cuda.LongTensor],
+    ) -> torch.Tensor:
         return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                   self.weight, self.reduction)
@@ -110,7 +122,9 @@ class SigmoidFocalLoss(nn.Module):
 class SoftmaxFocalLossFunction(Function):
 
     @staticmethod
-    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+    def symbolic(g, input: torch.Tensor, target: torch.LongTensor,
+                 gamma: float, alpha: float, weight: torch.Tensor,
+                 reduction: str):
         return g.op(
             'mmcv::MMCVSoftmaxFocalLoss',
             input,
@@ -122,12 +136,12 @@ class SoftmaxFocalLossFunction(Function):
     @staticmethod
     def forward(ctx,
-                input,
-                target,
-                gamma=2.0,
-                alpha=0.25,
-                weight=None,
-                reduction='mean'):
+                input: torch.Tensor,
+                target: Union[torch.LongTensor, torch.cuda.LongTensor],
+                gamma: float = 2.0,
+                alpha: float = 0.25,
+                weight: Optional[torch.Tensor] = None,
+                reduction='mean') -> torch.Tensor:
         assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
         assert input.dim() == 2
@@ -169,7 +183,7 @@ class SoftmaxFocalLossFunction(Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         input_softmax, target, weight = ctx.saved_tensors
         buff = input_softmax.new_zeros(input_softmax.size(0))
         grad_input = input_softmax.new_zeros(input_softmax.size())
@@ -194,14 +208,22 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
 class SoftmaxFocalLoss(nn.Module):
 
-    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+    def __init__(self,
+                 gamma: float,
+                 alpha: float,
+                 weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean'):
         super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)
         self.reduction = reduction
 
-    def forward(self, input, target):
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: Union[torch.LongTensor, torch.cuda.LongTensor],
+    ) -> torch.Tensor:
         return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                   self.weight, self.reduction)
```
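The annotations above do not change runtime behavior. A minimal usage sketch (my own, not part of the commit) of calling the two annotated modules with tensors of the hinted types might look like the following. It assumes mmcv is installed with its compiled ops and that a CUDA device is available, since these focal-loss kernels run on the GPU; the shapes, names, and hyperparameters are arbitrary.

```python
# Hedged usage sketch, not part of the diff: exercising SigmoidFocalLoss and
# SoftmaxFocalLoss with the tensor types the new hints describe.
# Assumes mmcv with compiled ops and an available CUDA device.
import torch
from mmcv.ops import SigmoidFocalLoss, SoftmaxFocalLoss

num_samples, num_classes = 8, 4
logits = torch.randn(num_samples, num_classes, device='cuda', requires_grad=True)
# `target` is an int64 class-index tensor, i.e. a torch.cuda.LongTensor here.
target = torch.randint(0, num_classes, (num_samples,), device='cuda')

sigmoid_criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25, reduction='mean')
softmax_criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25, reduction='mean')

loss = sigmoid_criterion(logits, target) + softmax_criterion(logits, target)
loss.backward()  # both Function subclasses define backward(), so gradients reach `logits`
print(loss.item(), logits.grad.shape)
```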