Commit b178d772 authored by Kai Chen's avatar Kai Chen
Browse files

fix focal loss

parent 3f412c39
......@@ -30,13 +30,20 @@ def sigmoid_focal_loss(pred,
weight,
gamma=2.0,
alpha=0.25,
reduction='elementwise_mean'):
reduction='mean'):
pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
return F.binary_cross_entropy_with_logits(
pred, target, weight, reduction=reduction)
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weighted_sigmoid_focal_loss(pred,
......@@ -58,22 +65,22 @@ def mask_cross_entropy(pred, target, label):
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='elementwise_mean')[None]
pred_slice, target, reduction='mean')[None]
def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
assert beta > 0
assert pred.size() == target.size() and target.numel() > 0
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
reduction = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction == 0:
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction == 1:
elif reduction_enum == 1:
return loss.sum() / pred.numel()
elif reduction == 2:
elif reduction_enum == 2:
return loss.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment