Commit b178d772 authored by Kai Chen

fix focal loss

parent 3f412c39
@@ -30,13 +30,20 @@ def sigmoid_focal_loss(pred,
                        weight,
                        gamma=2.0,
                        alpha=0.25,
-                       reduction='elementwise_mean'):
+                       reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
-    return F.binary_cross_entropy_with_logits(
-        pred, target, weight, reduction=reduction)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * weight
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean: 1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()
 
 
 def weighted_sigmoid_focal_loss(pred,
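Note on the hunk above: PyTorch 1.0 renamed the 'elementwise_mean' reduction to 'mean', and the focal term is now multiplied into the unreduced loss before the reduction is applied by hand, rather than being passed as the weight argument of F.binary_cross_entropy_with_logits. A minimal sketch, not part of the commit, checking that the per-element result matches the usual focal-loss formula -alpha_t * (1 - p_t)^gamma * log(p_t); the tensor shapes and random inputs are illustrative assumptions:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(8, 4)                      # raw logits
target = torch.randint(0, 2, (8, 4)).float()  # binary targets
gamma, alpha = 2.0, 0.25

# Same computation as the patched sigmoid_focal_loss, with weight = 1.
p = pred.sigmoid()
pt = (1 - p) * target + p * (1 - target)      # probability of the wrong class
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
loss = loss * focal_weight

# Reference focal loss, written directly from the formula.
p_t = p * target + (1 - p) * (1 - target)
alpha_t = alpha * target + (1 - alpha) * (1 - target)
ref = -alpha_t * (1 - p_t).pow(gamma) * torch.log(p_t)
assert torch.allclose(loss, ref, atol=1e-5)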
@@ -58,22 +65,22 @@ def mask_cross_entropy(pred, target, label):
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
     pred_slice = pred[inds, label].squeeze(1)
     return F.binary_cross_entropy_with_logits(
-        pred_slice, target, reduction='elementwise_mean')[None]
+        pred_slice, target, reduction='mean')[None]
 
 
-def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
+def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
     assert beta > 0
     assert pred.size() == target.size() and target.numel() > 0
     diff = torch.abs(pred - target)
     loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                        diff - 0.5 * beta)
-    reduction = F._Reduction.get_enum(reduction)
-    # none: 0, elementwise_mean:1, sum: 2
-    if reduction == 0:
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean: 1, sum: 2
+    if reduction_enum == 0:
         return loss
-    elif reduction == 1:
+    elif reduction_enum == 1:
         return loss.sum() / pred.numel()
-    elif reduction == 2:
+    elif reduction_enum == 2:
         return loss.sum()
...
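The mask_cross_entropy change is the same 'elementwise_mean' to 'mean' rename. In smooth_l1_loss, only the reduction name and the reused reduction variable (now reduction_enum) change; the 'mean' branch still reduces with loss.sum() / pred.numel(), a plain element-wise mean. A short usage sketch, assuming the patched smooth_l1_loss above is in scope; at beta=1.0 its piecewise definition coincides with PyTorch's built-in smooth L1, and the random inputs are illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(16, 4)    # e.g. predicted bbox deltas
target = torch.randn(16, 4)  # regression targets

loss = smooth_l1_loss(pred, target, beta=1.0, reduction='mean')

# At beta=1.0: 0.5 * diff**2 for diff < 1, diff - 0.5 otherwise,
# which is exactly F.smooth_l1_loss with mean reduction.
assert torch.allclose(loss, F.smooth_l1_loss(pred, target, reduction='mean'))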