Commit 13dffa31 authored by Gunho Park's avatar Gunho Park
Browse files

Internal change

parent de91fd22
...@@ -52,6 +52,7 @@ class BASNetLoss: ...@@ -52,6 +52,7 @@ class BASNetLoss:
total_iou_loss = tf.math.add_n(iou_losses) total_iou_loss = tf.math.add_n(iou_losses)
total_loss = total_bce_loss + total_ssim_loss + total_iou_loss total_loss = total_bce_loss + total_ssim_loss + total_iou_loss
total_loss = total_loss / len(levels)
return total_loss return total_loss
......
...@@ -30,7 +30,7 @@ class BASNetModel(tf.keras.Model): ...@@ -30,7 +30,7 @@ class BASNetModel(tf.keras.Model):
def __init__(self, def __init__(self,
backbone, backbone,
decoder, decoder,
refinement, refinement=None,
**kwargs): **kwargs):
"""BASNet initialization function. """BASNet initialization function.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment