Unverified Commit 6ab9251f authored by srihari-humbarwadi

compute `top_k` loss per sample

parent 054d11f5
@@ -76,18 +76,38 @@ class WeightedBootstrappedCrossEntropyLoss:
     if self._top_k_percent_pixels >= 1.0:
       loss = tf.reduce_sum(cross_entropy_loss) / normalizer
     else:
-      cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
-      top_k_pixels = tf.cast(
-          self._top_k_percent_pixels *
-          tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
-      top_k_losses, _ = tf.math.top_k(
-          cross_entropy_loss, k=top_k_pixels, sorted=True)
-      normalizer = tf.reduce_sum(
-          tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
-      loss = tf.reduce_sum(top_k_losses) / normalizer
+      loss = self._compute_top_k_loss(cross_entropy_loss)
     return loss
+
+  def _compute_top_k_loss(self, loss):
+    batch_size = tf.shape(loss)[0]
+    loss = tf.reshape(loss, shape=[batch_size, -1])
+    top_k_pixels = tf.cast(
+        self._top_k_percent_pixels *
+        tf.cast(tf.shape(loss)[-1], dtype=tf.float32),
+        dtype=tf.int32)
+
+    # shape: [batch_size, top_k_pixels]
+    per_sample_top_k_loss = tf.map_fn(
+        fn=lambda x: tf.nn.top_k(x, k=top_k_pixels, sorted=False)[0],
+        elems=loss,
+        parallel_iterations=32,
+        fn_output_signature=tf.float32)
+
+    # shape: [batch_size]
+    per_sample_normalizer = tf.reduce_sum(
+        tf.cast(
+            tf.not_equal(per_sample_top_k_loss, 0.0),
+            dtype=tf.float32),
+        axis=-1) + EPSILON
+    per_sample_normalized_loss = tf.reduce_sum(
+        per_sample_top_k_loss, axis=-1) / per_sample_normalizer
+
+    normalized_loss = tf_utils.safe_mean(per_sample_normalized_loss)
+    return normalized_loss
 class CenterHeatmapLoss:
 
   def __init__(self):
     self._loss_fn = tf.losses.mean_squared_error
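The diff above replaces the single global top-k over the flattened batch with a per-sample reduction: each sample keeps only its own hardest `_top_k_percent_pixels` fraction of pixel losses, normalizes by its own count of non-zero selected losses, and the per-sample values are then averaged. The following is a minimal standalone sketch of that reduction, not the committed code: it assumes TF 2.x eager mode, an illustrative `EPSILON` value, a toy ratio of 0.5, and plain `tf.reduce_mean` standing in for `tf_utils.safe_mean`.

import tensorflow as tf

EPSILON = 1e-5  # illustrative; the module's EPSILON constant is defined outside this hunk


def per_sample_top_k_loss(pixel_loss, top_k_percent_pixels=0.5):
  """Per-sample top-k reduction in the spirit of _compute_top_k_loss (sketch)."""
  batch_size = tf.shape(pixel_loss)[0]
  flat = tf.reshape(pixel_loss, [batch_size, -1])
  k = tf.cast(
      top_k_percent_pixels * tf.cast(tf.shape(flat)[-1], tf.float32), tf.int32)
  # Keep only the k largest pixel losses of each sample.
  top_k = tf.map_fn(
      fn=lambda x: tf.nn.top_k(x, k=k, sorted=False)[0],
      elems=flat,
      fn_output_signature=tf.float32)
  # Normalize each sample by its own count of non-zero selected losses.
  normalizer = tf.reduce_sum(
      tf.cast(tf.not_equal(top_k, 0.0), tf.float32), axis=-1) + EPSILON
  per_sample = tf.reduce_sum(top_k, axis=-1) / normalizer
  # The diff uses tf_utils.safe_mean here; plain reduce_mean is a stand-in.
  return tf.reduce_mean(per_sample)


# Sample 0 has one very hard pixel, sample 1 only easy pixels; each sample
# keeps its own two largest pixel losses before the results are averaged.
toy_loss = tf.constant([[4.0, 0.1, 0.1, 0.1],
                        [0.2, 0.2, 0.2, 0.2]])
print(per_sample_top_k_loss(toy_loss).numpy())  # (4.1/2 + 0.4/2) / 2 ≈ 1.125

One design note: `tf.nn.top_k` already reduces over the innermost axis of its input, so when every row shares the same `k` the per-row selection could also be written without `tf.map_fn`; the sketch keeps the `map_fn` form to stay close to the committed diff.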