Commit dabcbc97 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

darknet loss functionality update

parent 1654fbea
...@@ -349,16 +349,16 @@ class DarknetLoss(YoloLossBase): ...@@ -349,16 +349,16 @@ class DarknetLoss(YoloLossBase):
tf.cast(true_class, tf.int32), tf.cast(true_class, tf.int32),
depth=tf.shape(pred_class)[-1], depth=tf.shape(pred_class)[-1],
dtype=pred_class.dtype) dtype=pred_class.dtype)
true_classes = tf.stop_gradient(loss_utils.apply_mask(ind_mask, true_class)) true_class = tf.stop_gradient(loss_utils.apply_mask(ind_mask, true_class))
# Reorganize the one hot class list as a grid. # Reorganize the one hot class list as a grid.
true_class = loss_utils.build_grid( true_class_grid = loss_utils.build_grid(
inds, true_classes, pred_class, ind_mask, update=False) inds, true_class, pred_class, ind_mask, update=False)
true_class = tf.stop_gradient(true_class) true_class_grid = tf.stop_gradient(true_class_grid)
# Use the class mask to find the number of objects located in # Use the class mask to find the number of objects located in
# each predicted grid cell/pixel. # each predicted grid cell/pixel.
counts = true_class counts = true_class_grid
counts = tf.reduce_sum(counts, axis=-1, keepdims=True) counts = tf.reduce_sum(counts, axis=-1, keepdims=True)
reps = tf.gather_nd(counts, inds, batch_dims=1) reps = tf.gather_nd(counts, inds, batch_dims=1)
reps = tf.squeeze(reps, axis=-1) reps = tf.squeeze(reps, axis=-1)
...@@ -372,19 +372,21 @@ class DarknetLoss(YoloLossBase): ...@@ -372,19 +372,21 @@ class DarknetLoss(YoloLossBase):
box_loss = math_ops.divide_no_nan(box_loss, reps) box_loss = math_ops.divide_no_nan(box_loss, reps)
box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype) box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype)
if self._update_on_repeat:
# Compute the sigmoid binary cross entropy for the class maps. # Compute the sigmoid binary cross entropy for the class maps.
class_loss = tf.reduce_mean( class_loss = tf.reduce_mean(
loss_utils.sigmoid_bce( loss_utils.sigmoid_bce(
tf.expand_dims(true_class, axis=-1), tf.expand_dims(true_class_grid, axis=-1),
tf.expand_dims(pred_class, axis=-1), self._label_smoothing), tf.expand_dims(pred_class, axis=-1), self._label_smoothing),
axis=-1) axis=-1)
# Apply normalization to the class losses. # Apply normalization to the class losses.
if self._cls_normalizer < 1.0: if self._cls_normalizer < 1.0:
# Build a mask based on the true class locations. # Build a mask based on the true class locations.
cls_norm_mask = true_class cls_norm_mask = true_class_grid
# Apply the classes weight to class indexes were one_hot is one. # Apply the classes weight to class indexes were one_hot is one.
class_loss *= ((1 - cls_norm_mask) + cls_norm_mask * self._cls_normalizer) class_loss *= (
(1 - cls_norm_mask) + cls_norm_mask * self._cls_normalizer)
# Mask to the class loss and compute the sum over all the objects. # Mask to the class loss and compute the sum over all the objects.
class_loss = tf.reduce_sum(class_loss, axis=-1) class_loss = tf.reduce_sum(class_loss, axis=-1)
...@@ -392,6 +394,19 @@ class DarknetLoss(YoloLossBase): ...@@ -392,6 +394,19 @@ class DarknetLoss(YoloLossBase):
class_loss = math_ops.rm_nan_inf(class_loss, val=0.0) class_loss = math_ops.rm_nan_inf(class_loss, val=0.0)
class_loss = tf.cast( class_loss = tf.cast(
tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype) tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype)
else:
pred_class = loss_utils.apply_mask(
ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))
class_loss = tf.keras.losses.binary_crossentropy(
tf.expand_dims(true_class, axis = -1),
tf.expand_dims(pred_class, axis = -1),
label_smoothing=self._label_smoothing,
from_logits=True)
class_loss = loss_utils.apply_mask(ind_mask, class_loss)
class_loss = math_ops.divide_no_nan(
class_loss, tf.expand_dims(reps, axis = -1))
class_loss = tf.cast(tf.reduce_sum(
class_loss, axis=(1, 2)), dtype=y_pred.dtype)
# Compute the sigmoid binary cross entropy for the confidence maps. # Compute the sigmoid binary cross entropy for the confidence maps.
bce = tf.reduce_mean( bce = tf.reduce_mean(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment