Unverified Commit 1ac84b07 authored by Bin Lu's avatar Bin Lu Committed by GitHub
Browse files

Update ace_loss.py

parent c9d32d29
...@@ -32,6 +32,7 @@ class ACELoss(nn.Layer): ...@@ -32,6 +32,7 @@ class ACELoss(nn.Layer):
def __call__(self, predicts, batch): def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)): if isinstance(predicts, (list, tuple)):
predicts = predicts[-1] predicts = predicts[-1]
B, N = predicts.shape[:2] B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32') div = paddle.to_tensor([N]).astype('float32')
...@@ -42,9 +43,7 @@ class ACELoss(nn.Layer): ...@@ -42,9 +43,7 @@ class ACELoss(nn.Layer):
length = batch[2].astype("float32") length = batch[2].astype("float32")
batch = batch[3].astype("float32") batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length) batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div) batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch) loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss} return {"loss_ace": loss}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment