Unverified Commit 1ac84b07 authored by Bin Lu's avatar Bin Lu Committed by GitHub
Browse files

Update ace_loss.py

parent c9d32d29
......@@ -32,6 +32,7 @@ class ACELoss(nn.Layer):
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
......@@ -42,9 +43,7 @@ class ACELoss(nn.Layer):
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment