Unverified Commit b6ab6563 authored by Abhijit Deo, committed by GitHub

Focal Loss Documentation enhanced (#5799)



* Update focal_loss.py

updated the docstring for `sigmoid_focal_loss`

* edited docstring of focal_loss

* Update focal_loss.py

* fix doc style.

* formatting

* formatting

* minor edits

* move attribution link
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 6a53c9af
@@ -10,28 +10,28 @@ def sigmoid_focal_loss(
     alpha: float = 0.25,
     gamma: float = 2,
     reduction: str = "none",
-):
+) -> torch.Tensor:
     """
-    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
     Args:
-        inputs: A float tensor of arbitrary shape.
+        inputs (Tensor): A float tensor of arbitrary shape.
                 The predictions for each example.
-        targets: A float tensor with the same shape as inputs. Stores the binary
+        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
-        alpha: (optional) Weighting factor in range (0,1) to balance
-               positive vs negative examples or -1 for ignore. Default = 0.25
-        gamma: Exponent of the modulating factor (1 - p_t) to
-               balance easy vs hard examples.
-        reduction: 'none' | 'mean' | 'sum'
-                 'none': No reduction will be applied to the output.
-                 'mean': The output will be averaged.
-                 'sum': The output will be summed.
+        alpha (float): Weighting factor in range (0,1) to balance
+                positive vs negative examples or -1 for ignore. Default: ``0.25``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+                balance easy vs hard examples. Default: ``2``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
+                ``'none'``: No reduction will be applied to the output.
+                ``'mean'``: The output will be averaged.
+                ``'sum'``: The output will be summed. Default: ``'none'``.
     Returns:
         Loss tensor with the reduction option applied.
     """
+    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(sigmoid_focal_loss)
     p = torch.sigmoid(inputs)
......
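The docstring above documents the public signature, and the attribution comment points at the fvcore reference implementation that the truncated function body follows. As a reader's aid, here is a minimal, self-contained sketch of that computation, the focal loss FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t) from the RetinaNet paper cited in the docstring; the helper name focal_loss_sketch and the tensor shapes are illustrative, not part of torchvision:

import torch
import torch.nn.functional as F

def focal_loss_sketch(inputs, targets, alpha=0.25, gamma=2.0, reduction="none"):
    # Per-element binary cross-entropy computed on the raw logits.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: the predicted probability of each element's true class.
    p = torch.sigmoid(inputs)
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor (1 - p_t) ** gamma down-weights easy examples.
    loss = ce_loss * ((1 - p_t) ** gamma)
    # alpha balances positive vs negative examples; alpha = -1 disables it.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    # Apply the documented reduction options.
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

# Usage matching the documented defaults: arbitrary-shape float logits and
# same-shape binary targets.
logits = torch.randn(8, 4)
labels = torch.randint(0, 2, (8, 4)).float()
per_element = focal_loss_sketch(logits, labels)               # shape (8, 4)
scalar = focal_loss_sketch(logits, labels, reduction="mean")  # scalar for backward()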