Unverified Commit 85c0f85b authored by WINDSKY45, committed by GitHub

[Enhancement] Add type hints in mmcv/ops/focal_loss.py. (#1994)



* [Enhance] Add type hints in `focal_loss.py`.

* Use torch.LongTensor as the type of `target`.

* Fix a missing type hint.

* Remove some unnecessary type hints.

* Add missing type hints and fix the type hint of `target` (illustrated below).

* Fix the formatting of the type hints.

* Minor refinement.

* Minor fix.

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent f51bcf50
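To make the annotated contract concrete, here is a minimal usage sketch of the functional form after this change. The shapes and values are my own illustration, not part of the commit; the underlying op expects CUDA tensors, and `Function.apply` takes positional arguments only.

import torch
from mmcv.ops import sigmoid_focal_loss

# Illustrative only: 4 samples, 3 classes (values are not from the commit).
pred = torch.randn(4, 3).cuda()             # input: torch.Tensor (float logits)
target = torch.tensor([0, 2, 1, 0]).cuda()  # target: cuda.LongTensor of class indices
# Positional arguments: input, target, gamma, alpha, weight, reduction.
loss = sigmoid_focal_loss(pred, target, 2.0, 0.25, None, 'mean')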
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
 from torch.autograd import Function
@@ -15,7 +17,9 @@ ext_module = ext_loader.load_ext('_ext', [
 class SigmoidFocalLossFunction(Function):
 
     @staticmethod
-    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+    def symbolic(g, input: torch.Tensor, target: torch.LongTensor,
+                 gamma: float, alpha: float, weight: torch.Tensor,
+                 reduction: str):
         return g.op(
             'mmcv::MMCVSigmoidFocalLoss',
             input,
@@ -27,12 +31,12 @@ class SigmoidFocalLossFunction(Function):
     @staticmethod
     def forward(ctx,
-                input,
-                target,
-                gamma=2.0,
-                alpha=0.25,
-                weight=None,
-                reduction='mean'):
+                input: torch.Tensor,
+                target: Union[torch.LongTensor, torch.cuda.LongTensor],
+                gamma: float = 2.0,
+                alpha: float = 0.25,
+                weight: Optional[torch.Tensor] = None,
+                reduction: str = 'mean') -> torch.Tensor:
         assert isinstance(
             target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
@@ -64,7 +68,7 @@ class SigmoidFocalLossFunction(Function):
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         input, target, weight = ctx.saved_tensors
         grad_input = input.new_zeros(input.size())
@@ -88,14 +92,22 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
 class SigmoidFocalLoss(nn.Module):
 
-    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+    def __init__(self,
+                 gamma: float,
+                 alpha: float,
+                 weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean'):
         super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)
         self.reduction = reduction
 
-    def forward(self, input, target):
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: Union[torch.LongTensor, torch.cuda.LongTensor],
+    ) -> torch.Tensor:
         return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                   self.weight, self.reduction)
@@ -110,7 +122,9 @@ class SigmoidFocalLoss(nn.Module):
 class SoftmaxFocalLossFunction(Function):
 
     @staticmethod
-    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+    def symbolic(g, input: torch.Tensor, target: torch.LongTensor,
+                 gamma: float, alpha: float, weight: torch.Tensor,
+                 reduction: str):
         return g.op(
             'mmcv::MMCVSoftmaxFocalLoss',
             input,
@@ -122,12 +136,12 @@ class SoftmaxFocalLossFunction(Function):
     @staticmethod
     def forward(ctx,
-                input,
-                target,
-                gamma=2.0,
-                alpha=0.25,
-                weight=None,
-                reduction='mean'):
+                input: torch.Tensor,
+                target: Union[torch.LongTensor, torch.cuda.LongTensor],
+                gamma: float = 2.0,
+                alpha: float = 0.25,
+                weight: Optional[torch.Tensor] = None,
+                reduction='mean') -> torch.Tensor:
         assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
         assert input.dim() == 2
@@ -169,7 +183,7 @@ class SoftmaxFocalLossFunction(Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         input_softmax, target, weight = ctx.saved_tensors
         buff = input_softmax.new_zeros(input_softmax.size(0))
         grad_input = input_softmax.new_zeros(input_softmax.size())
@@ -194,14 +208,22 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
 class SoftmaxFocalLoss(nn.Module):
 
-    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+    def __init__(self,
+                 gamma: float,
+                 alpha: float,
+                 weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean'):
         super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)
         self.reduction = reduction
 
-    def forward(self, input, target):
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: Union[torch.LongTensor, torch.cuda.LongTensor],
+    ) -> torch.Tensor:
         return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                   self.weight, self.reduction)
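A corresponding sketch for the module wrapper (again my own example, not part of the commit): per the asserts above, `SoftmaxFocalLoss` expects a 2-D float `input` and integer class indices as `target`.

import torch
from mmcv.ops import SoftmaxFocalLoss

# Illustrative only: 8 samples, 5 classes; the op expects CUDA tensors.
criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25)
logits = torch.randn(8, 5).cuda().requires_grad_()  # input: torch.Tensor, dim() == 2
labels = torch.randint(0, 5, (8,)).cuda()           # target: cuda.LongTensor
loss = criterion(logits, labels)                    # -> scalar torch.Tensor ('mean')
loss.backward()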