Unverified Commit b6ab6563 authored by Abhijit Deo's avatar Abhijit Deo Committed by GitHub
Browse files

Focal Loss Documentation enhanced (#5799)



* Update focal_loss.py

updated docstring for the `sigmoid_focal_loss`

* edited docstring of focal_loss

* Update focal_loss.py

* fix doc style.

* formatting

* formatting

* minor edits

* move attribution link
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 6a53c9af
...@@ -10,28 +10,28 @@ def sigmoid_focal_loss( ...@@ -10,28 +10,28 @@ def sigmoid_focal_loss(
alpha: float = 0.25, alpha: float = 0.25,
gamma: float = 2, gamma: float = 2,
reduction: str = "none", reduction: str = "none",
): ) -> torch.Tensor:
""" """
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args: Args:
inputs: A float tensor of arbitrary shape. inputs (Tensor): A float tensor of arbitrary shape.
The predictions for each example. The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs classification label for each element in inputs
(0 for the negative class and 1 for the positive class). (0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance alpha (float): Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default = 0.25 positive vs negative examples or -1 for ignore. Default: ``0.25``.
gamma: Exponent of the modulating factor (1 - p_t) to gamma (float): Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples. balance easy vs hard examples. Default: ``2``.
reduction: 'none' | 'mean' | 'sum' reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
'none': No reduction will be applied to the output. ``'none'``: No reduction will be applied to the output.
'mean': The output will be averaged. ``'mean'``: The output will be averaged.
'sum': The output will be summed. ``'sum'``: The output will be summed. Default: ``'none'``.
Returns: Returns:
Loss tensor with the reduction option applied. Loss tensor with the reduction option applied.
""" """
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(sigmoid_focal_loss) _log_api_usage_once(sigmoid_focal_loss)
p = torch.sigmoid(inputs) p = torch.sigmoid(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment