Commit c4408812 authored by Kai Chen's avatar Kai Chen
Browse files

bug fix for focal loss

parent ddda131f
......@@ -35,7 +35,8 @@ def sigmoid_focal_loss(pred,
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * weight
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment