Commit a2079427 authored by Kai Chen's avatar Kai Chen
Browse files

bug fix

parent 13dec7ea
...@@ -31,7 +31,7 @@ def sigmoid_focal_loss(pred, ...@@ -31,7 +31,7 @@ def sigmoid_focal_loss(pred,
gamma=2.0, gamma=2.0,
alpha=0.25, alpha=0.25,
reduction='elementwise_mean'): reduction='elementwise_mean'):
pred_sigmoid = pred.sigmoid().detach() pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma) weight = weight * pt.pow(gamma)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment