Commit f724f9ac authored by qiang chen's avatar qiang chen Committed by Kai Chen
Browse files

fix potential bug in binary_cross_entropy (#846)

parent 32df98e9
...@@ -29,14 +29,13 @@ def binary_cross_entropy(pred, ...@@ -29,14 +29,13 @@ def binary_cross_entropy(pred,
if pred.dim() != label.dim(): if pred.dim() != label.dim():
label, weight = _expand_binary_labels(label, weight, pred.size(-1)) label, weight = _expand_binary_labels(label, weight, pred.size(-1))
# element-wise losses # weighted element-wise losses
if weight is not None: if weight is not None:
weight = weight.float() weight = weight.float()
loss = F.binary_cross_entropy_with_logits( loss = F.binary_cross_entropy_with_logits(
pred, label.float(), weight, reduction='none') pred, label.float(), weight, reduction='none')
# apply weights and do the reduction # do the reduction for the weighted loss
loss = weight_reduce_loss( loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment