import warnings

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
        pred,
        target,
        weight=1.0,
        gamma=2.0,
        alpha=0.25,
        reduction='mean'):
    """Compute the sigmoid focal loss on raw logits.

    The loss is ``BCE(pred, target) * alpha_t * (1 - p_t) ** gamma * weight``.
    Here ``alpha_t`` is ``alpha`` for positive targets and ``1 - alpha`` for
    negative ones. ``p_t`` is the predicted probability of the target class.

    Args:
        pred: Logits (pre-sigmoid) tensor.
        target: Targets with the same shape as ``pred``, cast to ``pred``'s
            dtype. Presumably values in ``{0, 1}``; not checked here.
        weight: Extra multiplicative weight, either a scalar or a tensor
            broadcastable against the element-wise loss.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        alpha: Weight for positive targets (``1 - alpha`` for negatives).
        reduction: ``'none'``, ``'mean'`` or ``'sum'``. The deprecated alias
            ``'elementwise_mean'`` is accepted as ``'mean'`` with a warning.

    Returns:
        The element-wise loss tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not a supported value.
    """
    # Validate up front via public behaviour instead of the private
    # torch.nn._reduction helper; this mirrors its accepted values.
    if reduction == 'elementwise_mean':
        warnings.warn(
            "reduction='elementwise_mean' is deprecated. "
            "Please use reduction='mean' instead.")
        reduction = 'mean'
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt here is (1 - p_t): the probability assigned to the *wrong* class.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    focal_weight = focal_weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight

    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    return loss.sum()